From 86f7f245e4331d6025496a27a80fd4bcc72e0802 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:07:06 +0800 Subject: [PATCH 01/13] fix: The length of the tag should between 1 and 50 (#8187) (#8188) --- api/controllers/console/tag/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 7293aeeb34cef9..de30547e93f937 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -13,7 +13,7 @@ def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: + if not name or len(name) < 1 or len(name) > 50: raise ValueError("Name must be between 1 to 50 characters.") return name From af92f19291cc407ae6b70c3fc0b73276d5edbd47 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:55:08 +0800 Subject: [PATCH 02/13] filter excel empty sheet (#8194) --- api/core/rag/extractor/excel_extractor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index f0c302a6197a64..526c66042ca244 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -38,7 +38,10 @@ def extract(self) -> list[Document]: for sheet_name in wb.sheetnames: sheet = wb[sheet_name] data = sheet.values - cols = next(data) + try: + cols = next(data) + except StopIteration: + continue df = pd.DataFrame(data, columns=cols) df.dropna(how='all', inplace=True) From ed37439ef791e68964f5eaed90d8df41934787ad Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 10 Sep 2024 15:00:25 +0800 Subject: [PATCH 03/13] refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166) --- api/core/agent/base_agent_runner.py | 239 +++++++++++++++------------- api/core/model_manager.py | 157 ++++++++---------- 2 files changed, 199 insertions(+), 197 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index d8290ca608b0cb..d09a9956a4a591 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Optional, Union, cast @@ -45,22 +46,25 @@ logger = logging.getLogger(__name__) + class BaseAgentRunner(AppRunner): - def __init__(self, tenant_id: str, - application_generate_entity: AgentChatAppGenerateEntity, - conversation: Conversation, - app_config: AgentChatAppConfig, - model_config: ModelConfigWithCredentialsEntity, - config: AgentEntity, - queue_manager: AppQueueManager, - message: Message, - user_id: str, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, - variables_pool: Optional[ToolRuntimeVariablePool] = None, - db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None - ) -> None: + def __init__( + self, + tenant_id: str, + application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, + config: AgentEntity, + queue_manager: AppQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[list[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance = None, + ) -> None: """ Agent runner :param tenant_id: tenant id @@ -88,9 +92,7 @@ def __init__(self, tenant_id: str, self.message = message self.user_id = user_id self.memory = memory - self.history_prompt_messages = self.organize_agent_history( - prompt_messages=prompt_messages or [] - ) + self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.variables_pool = variables_pool self.db_variables_pool = db_variables self.model_instance = model_instance @@ -111,12 +113,16 @@ def __init__(self, tenant_id: str, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # get how many agent thoughts have been created - self.agent_thought_count = db.session.query(MessageAgentThought).filter( - MessageAgentThought.message_id == self.message.id, - ).count() + self.agent_thought_count = ( + db.session.query(MessageAgentThought) + .filter( + MessageAgentThought.message_id == self.message.id, + ) + .count() + ) db.session.close() # check if model supports stream tool call @@ -135,25 +141,26 @@ def __init__(self, tenant_id: str, self.query = None self._current_thoughts: list[PromptMessage] = [] - def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ - -> AgentChatAppGenerateEntity: + def _repack_app_generate_entity( + self, app_generate_entity: AgentChatAppGenerateEntity + ) -> AgentChatAppGenerateEntity: """ Repack app generate entity """ if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: - app_generate_entity.app_config.prompt_template.simple_prompt_template = '' + app_generate_entity.app_config.prompt_template.simple_prompt_template = "" return app_generate_entity - + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ - convert tool to prompt message tool + convert tool to prompt message tool """ tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, - invoke_from=self.application_generate_entity.invoke_from + invoke_from=self.application_generate_entity.invoke_from, ) tool_entity.load_variables(self.variables_pool) @@ -164,7 +171,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P "type": "object", "properties": {}, "required": [], - } + }, ) parameters = tool_entity.get_all_runtime_parameters() @@ -177,19 +184,19 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - message_tool.parameters['properties'][parameter.name] = { + message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + message_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - message_tool.parameters['required'].append(parameter.name) + message_tool.parameters["required"].append(parameter.name) return message_tool, tool_entity - + def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: """ convert dataset retriever tool to prompt message tool @@ -201,24 +208,24 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe "type": "object", "properties": {}, "required": [], - } + }, ) for parameter in tool.get_runtime_parameters(): - parameter_type = 'string' - - prompt_tool.parameters['properties'][parameter.name] = { + parameter_type = "string" + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + + def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: """ Init tools """ @@ -261,51 +268,51 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - - prompt_tool.parameters['properties'][parameter.name] = { + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + prompt_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + + def create_agent_thought( + self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] + ) -> MessageAgentThought: """ Create agent thought """ thought = MessageAgentThought( message_id=message_id, message_chain_id=None, - thought='', + thought="", tool=tool_name, - tool_labels_str='{}', - tool_meta_str='{}', + tool_labels_str="{}", + tool_meta_str="{}", tool_input=tool_input, message=message, message_token=0, message_unit_price=0, message_price_unit=0, - message_files=json.dumps(messages_ids) if messages_ids else '', - answer='', - observation='', + message_files=json.dumps(messages_ids) if messages_ids else "", + answer="", + observation="", answer_token=0, answer_unit_price=0, answer_price_unit=0, tokens=0, total_price=0, position=self.agent_thought_count + 1, - currency='USD', + currency="USD", latency=0, - created_by_role='account', + created_by_role="account", created_by=self.user_id, ) @@ -318,22 +325,22 @@ def create_agent_thought(self, message_id: str, message: str, return thought - def save_agent_thought(self, - agent_thought: MessageAgentThought, - tool_name: str, - tool_input: Union[str, dict], - thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], - answer: str, - messages_ids: list[str], - llm_usage: LLMUsage = None) -> MessageAgentThought: + def save_agent_thought( + self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: Union[str, dict], + tool_invoke_meta: Union[str, dict], + answer: str, + messages_ids: list[str], + llm_usage: LLMUsage = None, + ) -> MessageAgentThought: """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter( - MessageAgentThought.id == agent_thought.id - ).first() + agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() if thought is not None: agent_thought.thought = thought @@ -356,7 +363,7 @@ def save_agent_thought(self, observation = json.dumps(observation, ensure_ascii=False) except Exception as e: observation = json.dumps(observation) - + agent_thought.observation = observation if answer is not None: @@ -364,7 +371,7 @@ def save_agent_thought(self, if messages_ids is not None and len(messages_ids) > 0: agent_thought.message_files = json.dumps(messages_ids) - + if llm_usage: agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_price_unit = llm_usage.prompt_price_unit @@ -377,7 +384,7 @@ def save_agent_thought(self, # check if tool labels is not empty labels = agent_thought.tool_labels or {} - tools = agent_thought.tool.split(';') if agent_thought.tool else [] + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -386,7 +393,7 @@ def save_agent_thought(self, if tool_label: labels[tool] = tool_label.to_dict() else: - labels[tool] = {'en_US': tool, 'zh_Hans': tool} + labels[tool] = {"en_US": tool, "zh_Hans": tool} agent_thought.tool_labels_str = json.dumps(labels) @@ -401,14 +408,18 @@ def save_agent_thought(self, db.session.commit() db.session.close() - + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ convert tool variables to db variables """ - db_variables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == self.message.conversation_id, - ).first() + db_variables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ) + .first() + ) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) @@ -425,9 +436,14 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = db.session.query(Message).filter( - Message.conversation_id == self.message.conversation_id, - ).order_by(Message.created_at.asc()).all() + messages: list[Message] = ( + db.session.query(Message) + .filter( + Message.conversation_id == self.message.conversation_id, + ) + .order_by(Message.created_at.asc()) + .all() + ) for message in messages: if message.id == self.message.id: @@ -439,13 +455,13 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P for agent_thought in agent_thoughts: tools = agent_thought.tool if tools: - tools = tools.split(';') + tools = tools.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] try: tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: - tool_inputs = { tool: {} for tool in tools } + tool_inputs = {tool: {} for tool in tools} try: tool_responses = json.loads(agent_thought.observation) except Exception as e: @@ -454,27 +470,33 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ), + ) + ) + tool_call_response.append( + ToolPromptMessage( + content=tool_responses.get(tool, agent_thought.observation), name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), + tool_call_id=tool_call_id, ) - )) - tool_call_response.append(ToolPromptMessage( - content=tool_responses.get(tool, agent_thought.observation), - name=tool, - tool_call_id=tool_call_id, - )) - - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + ) + + result.extend( + [ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response, + ] + ) if not tools: result.append(AssistantPromptMessage(content=agent_thought.thought)) else: @@ -496,10 +518,7 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7b1a7ada5b2225..990efd36c609c2 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging import os -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -41,7 +41,7 @@ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> No configuration=provider_model_bundle.configuration, model_type=provider_model_bundle.model_type_instance.model_type, model=model, - credentials=self.credentials + credentials=self.credentials, ) @staticmethod @@ -54,10 +54,7 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m """ configuration = provider_model_bundle.configuration model_type = provider_model_bundle.model_type_instance.model_type - credentials = configuration.get_current_credentials( - model_type=model_type, - model=model - ) + credentials = configuration.get_current_credentials(model_type=model_type, model=model) if credentials is None: raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") @@ -65,10 +62,9 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m return credentials @staticmethod - def _get_load_balancing_manager(configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict) -> Optional["LBModelManager"]: + def _get_load_balancing_manager( + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + ) -> Optional["LBModelManager"]: """ Get load balancing model credentials :param configuration: provider configuration @@ -81,8 +77,7 @@ def _get_load_balancing_manager(configuration: ProviderConfiguration, current_model_setting = None # check if model is disabled by admin for model_setting in configuration.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: current_model_setting = model_setting break @@ -95,17 +90,23 @@ def _get_load_balancing_manager(configuration: ProviderConfiguration, model_type=model_type, model=model, load_balancing_configs=current_model_setting.load_balancing_configs, - managed_credentials=credentials if configuration.custom_configuration.provider else None + managed_credentials=credentials if configuration.custom_configuration.provider else None, ) return lb_model_manager return None - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -132,11 +133,12 @@ def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Opt stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) - def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_llm_num_tokens( + self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Get number of tokens for llm @@ -153,11 +155,10 @@ def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, - tools=tools + tools=tools, ) - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult: """ Invoke large language model @@ -174,7 +175,7 @@ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, texts=texts, - user=user + user=user, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -192,13 +193,17 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: function=self.model_type_instance.get_num_tokens, model=self.model, credentials=self.credentials, - texts=texts + texts=texts, ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke_rerank( + self, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -221,11 +226,10 @@ def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[f docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user + user=user, ) - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: + def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -242,11 +246,10 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) \ model=self.model, credentials=self.credentials, text=text, - user=user + user=user, ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -263,11 +266,10 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, file=file, - user=user + user=user, ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ - -> str: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: """ Invoke large language tts model @@ -288,7 +290,7 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option content_text=content_text, user=user, tenant_id=tenant_id, - voice=voice + voice=voice, ) def _round_robin_invoke(self, function: Callable, *args, **kwargs): @@ -312,8 +314,8 @@ def _round_robin_invoke(self, function: Callable, *args, **kwargs): raise last_exception try: - if 'credentials' in kwargs: - del kwargs['credentials'] + if "credentials" in kwargs: + del kwargs["credentials"] return function(*args, **kwargs, credentials=lb_config.credentials) except InvokeRateLimitError as e: # expire in 60 seconds @@ -340,9 +342,7 @@ def get_tts_voices(self, language: Optional[str] = None) -> list: self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language + model=self.model, credentials=self.credentials, language=language ) @@ -363,9 +363,7 @@ def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelTyp return self.get_default_model_instance(tenant_id, model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=provider, - model_type=model_type + tenant_id=tenant_id, provider=provider, model_type=model_type ) return ModelInstance(provider_model_bundle, model) @@ -386,10 +384,7 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M :param model_type: model type :return: """ - default_model_entity = self._provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type - ) + default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type) if not default_model_entity: raise ProviderTokenNotInitError(f"Default model not found for {model_type}") @@ -398,17 +393,20 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M tenant_id=tenant_id, provider=default_model_entity.provider.provider, model_type=model_type, - model=default_model_entity.model + model=default_model_entity.model, ) class LBModelManager: - def __init__(self, tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None) -> None: + def __init__( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None, + ) -> None: """ Load balancing model manager :param tenant_id: tenant_id @@ -439,10 +437,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: :return: """ cache_key = "model_lb_index:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model + self._tenant_id, self._provider, self._model_type.value, self._model ) cooldown_load_balancing_configs = [] @@ -473,10 +468,12 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: continue - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" - f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" - f"model_type: {self._model_type.value}\nmodel: {self._model}") + if bool(os.environ.get("DEBUG", "False").lower() == "true"): + logger.info( + f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}" + ) return config @@ -490,14 +487,10 @@ def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - redis_client.setex(cooldown_cache_key, expire, 'true') + redis_client.setex(cooldown_cache_key, expire, "true") def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: """ @@ -506,11 +499,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) res = redis_client.exists(cooldown_cache_key) @@ -518,11 +507,9 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: return res @staticmethod - def get_config_in_cooldown_and_ttl(tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - config_id: str) -> tuple[bool, int]: + def get_config_in_cooldown_and_ttl( + tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str + ) -> tuple[bool, int]: """ Get model load balancing config is in cooldown and ttl :param tenant_id: workspace id @@ -533,11 +520,7 @@ def get_config_in_cooldown_and_ttl(tenant_id: str, :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - tenant_id, - provider, - model_type.value, - model, - config_id + tenant_id, provider, model_type.value, model, config_id ) ttl = redis_client.ttl(cooldown_cache_key) From 5da0182800f5900e0cc2812754ede12402fd8972 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:02:52 +0800 Subject: [PATCH 04/13] docs: replace docker-compose with docker compose (#8195) --- docker/certbot/README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docker/certbot/README.md b/docker/certbot/README.md index c6f73ae699df15..21be34b33adde2 100644 --- a/docker/certbot/README.md +++ b/docker/certbot/README.md @@ -2,8 +2,8 @@ ## Short description -Docker-compose certbot configurations with Backward compatibility (without certbot container). -Use `docker-compose --profile certbot up` to use this features. +docker compose certbot configurations with Backward compatibility (without certbot container). +Use `docker compose --profile certbot up` to use this features. ## The simplest way for launching new servers with SSL certificates @@ -18,21 +18,21 @@ Use `docker-compose --profile certbot up` to use this features. ``` execute command: ```shell - sudo docker network prune - sudo docker-compose --profile certbot up --force-recreate -d + docker network prune + docker compose --profile certbot up --force-recreate -d ``` then after the containers launched: ```shell - sudo docker-compose exec -it certbot /bin/sh /update-cert.sh + docker compose exec -it certbot /bin/sh /update-cert.sh ``` -2. Edit `.env` file and `sudo docker-compose --profile certbot up` again. +2. Edit `.env` file and `docker compose --profile certbot up` again. set `.env` value additionally ```properties NGINX_HTTPS_ENABLED=true ``` execute command: ```shell - sudo docker-compose --profile certbot up -d --no-deps --force-recreate nginx + docker compose --profile certbot up -d --no-deps --force-recreate nginx ``` Then you can access your serve with HTTPS. [https://your_domain.com](https://your_domain.com) @@ -42,8 +42,8 @@ Use `docker-compose --profile certbot up` to use this features. For SSL certificates renewal, execute commands below: ```shell -sudo docker-compose exec -it certbot /bin/sh /update-cert.sh -sudo docker-compose exec nginx nginx -s reload +docker compose exec -it certbot /bin/sh /update-cert.sh +docker compose exec nginx nginx -s reload ``` ## Options for certbot @@ -57,14 +57,14 @@ CERTBOT_OPTIONS=--dry-run To apply changes to `CERTBOT_OPTIONS`, regenerate the certbot container before updating the certificates. ```shell -sudo docker-compose --profile certbot up -d --no-deps --force-recreate certbot -sudo docker-compose exec -it certbot /bin/sh /update-cert.sh +docker compose --profile certbot up -d --no-deps --force-recreate certbot +docker compose exec -it certbot /bin/sh /update-cert.sh ``` Then, reload the nginx container if necessary. ```shell -sudo docker-compose exec nginx nginx -s reload +docker compose exec nginx nginx -s reload ``` ## For legacy servers @@ -72,5 +72,5 @@ sudo docker-compose exec nginx nginx -s reload To use cert files dir `nginx/ssl` as before, simply launch containers WITHOUT `--profile certbot` option. ```shell -sudo docker-compose up -d +docker compose up -d ``` From dabfd74622a613c7198f06791ad70424ac94f54f Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Sep 2024 15:23:16 +0800 Subject: [PATCH 05/13] feat: Parallel Execution of Nodes in Workflows (#8192) Co-authored-by: StyleZhang Co-authored-by: Yi Co-authored-by: -LAN- --- api/configs/packaging/__init__.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 169 +-- api/core/app/apps/advanced_chat/app_runner.py | 290 ++--- .../advanced_chat/generate_task_pipeline.py | 662 ++++------- .../workflow_event_trigger_callback.py | 203 ---- .../base_app_generate_response_converter.py | 2 +- api/core/app/apps/base_app_runner.py | 6 +- api/core/app/apps/workflow/app_generator.py | 65 +- api/core/app/apps/workflow/app_runner.py | 145 ++- .../apps/workflow/generate_task_pipeline.py | 433 +++---- .../workflow_event_trigger_callback.py | 200 ---- api/core/app/apps/workflow_app_runner.py | 379 +++++++ .../app/apps/workflow_logging_callback.py | 284 +++-- api/core/app/entities/queue_entities.py | 182 ++- api/core/app/entities/task_entities.py | 156 +-- .../based_generate_task_pipeline.py | 14 +- .../app/task_pipeline/message_cycle_manage.py | 36 +- .../task_pipeline/workflow_cycle_manage.py | 665 ++++++----- .../workflow_cycle_state_manager.py | 16 - .../workflow_iteration_cycle_manage.py | 290 ----- .../model_runtime/entities/llm_entities.py | 33 + api/core/moderation/output_moderation.py | 8 +- api/core/tools/tool/workflow_tool.py | 4 +- api/core/tools/tool_engine.py | 2 + api/core/tools/tool_manager.py | 3 +- api/core/tools/utils/message_transformer.py | 1 + .../callbacks/base_workflow_callback.py | 113 +- .../entities/base_node_data_entities.py | 2 +- api/core/workflow/entities/node_entities.py | 34 +- api/core/workflow/entities/variable_pool.py | 86 +- .../workflow/entities/workflow_entities.py | 5 +- api/core/workflow/errors.py | 10 +- api/core/workflow/graph_engine/__init__.py | 0 .../condition_handlers/__init__.py | 0 .../condition_handlers/base_handler.py | 31 + .../branch_identify_handler.py | 28 + .../condition_handlers/condition_handler.py | 32 + .../condition_handlers/condition_manager.py | 35 + .../graph_engine/entities/__init__.py | 0 .../workflow/graph_engine/entities/event.py | 163 +++ .../workflow/graph_engine/entities/graph.py | 692 ++++++++++++ .../entities/graph_init_params.py | 21 + .../entities/graph_runtime_state.py | 27 + .../graph_engine/entities/next_graph_node.py | 13 + .../graph_engine/entities/run_condition.py | 21 + .../entities/runtime_route_state.py | 111 ++ .../workflow/graph_engine/graph_engine.py | 716 ++++++++++++ api/core/workflow/nodes/answer/answer_node.py | 91 +- .../answer/answer_stream_generate_router.py | 169 +++ .../nodes/answer/answer_stream_processor.py | 221 ++++ .../nodes/answer/base_stream_processor.py | 71 ++ api/core/workflow/nodes/answer/entities.py | 42 +- api/core/workflow/nodes/base_node.py | 196 +--- api/core/workflow/nodes/code/code_node.py | 24 +- api/core/workflow/nodes/end/end_node.py | 62 +- .../nodes/end/end_stream_generate_router.py | 148 +++ .../nodes/end/end_stream_processor.py | 191 ++++ api/core/workflow/nodes/end/entities.py | 16 + api/core/workflow/nodes/event.py | 20 + .../nodes/http_request/http_request_node.py | 28 +- api/core/workflow/nodes/if_else/entities.py | 15 +- .../workflow/nodes/if_else/if_else_node.py | 405 +------ api/core/workflow/nodes/iteration/entities.py | 9 +- .../nodes/iteration/iteration_node.py | 429 +++++-- .../nodes/iteration/iteration_start_node.py | 39 + .../knowledge_retrieval_node.py | 32 +- api/core/workflow/nodes/llm/llm_node.py | 166 ++- api/core/workflow/nodes/loop/loop_node.py | 30 +- api/core/workflow/nodes/node_mapping.py | 37 + .../parameter_extractor_node.py | 35 +- .../question_classifier_node.py | 50 +- api/core/workflow/nodes/start/start_node.py | 22 +- .../template_transform_node.py | 26 +- api/core/workflow/nodes/tool/tool_node.py | 23 +- .../variable_aggregator_node.py | 25 +- .../workflow/nodes/variable_assigner/node.py | 13 +- api/core/workflow/utils/condition/__init__.py | 0 api/core/workflow/utils/condition/entities.py | 17 + .../workflow/utils/condition/processor.py | 383 +++++++ api/core/workflow/workflow_engine_manager.py | 1005 ----------------- api/core/workflow/workflow_entry.py | 314 +++++ ...21501b_add_node_execution_id_into_node_.py | 35 + api/models/workflow.py | 3 + api/services/app_dsl_service.py | 3 +- api/services/app_generate_service.py | 7 +- api/services/workflow_service.py | 160 ++- .../workflow/nodes/test_code.py | 319 +++--- .../workflow/nodes/test_http.py | 156 ++- .../workflow/nodes/test_llm.py | 133 ++- .../nodes/test_parameter_extractor.py | 192 ++-- .../workflow/nodes/test_template_transform.py | 84 +- .../workflow/nodes/test_tool.py | 84 +- api/tests/unit_tests/conftest.py | 17 + .../core/workflow/graph_engine/__init__.py | 0 .../core/workflow/graph_engine/test_graph.py | 791 +++++++++++++ .../graph_engine/test_graph_engine.py | 505 +++++++++ .../core/workflow/nodes/answer/__init__.py | 0 .../core/workflow/nodes/answer/test_answer.py | 82 ++ .../test_answer_stream_generate_router.py | 109 ++ .../answer/test_answer_stream_processor.py | 216 ++++ .../core/workflow/nodes/iteration/__init__.py | 0 .../nodes/iteration/test_iteration.py | 420 +++++++ .../core/workflow/nodes/test_answer.py | 65 +- .../core/workflow/nodes/test_if_else.py | 126 ++- .../workflow/nodes/test_variable_assigner.py | 205 +++- docker-legacy/docker-compose.yaml | 6 +- docker/docker-compose.yaml | 6 +- .../chat/chat/answer/workflow-process.tsx | 18 +- web/app/components/base/chat/chat/hooks.ts | 21 +- .../share/text-generation/result/index.tsx | 22 +- .../components/workflow/candidate-node.tsx | 4 + web/app/components/workflow/constants.ts | 21 +- .../workflow/hooks/use-nodes-interactions.ts | 235 ++-- .../workflow/hooks/use-workflow-run.ts | 113 +- .../workflow/hooks/use-workflow-template.ts | 6 +- .../components/workflow/hooks/use-workflow.ts | 58 +- web/app/components/workflow/index.tsx | 5 + web/app/components/workflow/limit-tips.tsx | 39 + .../nodes/_base/components/next-step/add.tsx | 24 +- .../_base/components/next-step/container.tsx | 55 + .../_base/components/next-step/index.tsx | 88 +- .../nodes/_base/components/next-step/item.tsx | 88 +- .../nodes/_base/components/next-step/line.tsx | 90 +- .../_base/components/next-step/operator.tsx | 129 +++ .../nodes/_base/components/node-handle.tsx | 42 +- .../nodes/_base/components/node-resizer.tsx | 4 +- .../components/workflow/nodes/_base/node.tsx | 10 + .../workflow/nodes/if-else/types.ts | 1 - .../workflow/nodes/if-else/utils.ts | 1 - .../nodes/iteration-start/constants.ts | 1 + .../workflow/nodes/iteration-start/default.ts | 21 + .../workflow/nodes/iteration-start/index.tsx | 42 + .../workflow/nodes/iteration-start/types.ts | 3 + .../workflow/nodes/iteration/add-block.tsx | 95 +- .../workflow/nodes/iteration/default.ts | 1 + .../workflow/nodes/iteration/insert-block.tsx | 61 - .../workflow/nodes/iteration/node.tsx | 20 +- .../nodes/iteration/use-interactions.ts | 6 +- .../workflow/operator/add-block.tsx | 2 +- web/app/components/workflow/operator/hooks.ts | 2 +- .../workflow/panel/debug-and-preview/hooks.ts | 107 +- web/app/components/workflow/run/index.tsx | 23 +- .../workflow/run/iteration-result-panel.tsx | 78 +- web/app/components/workflow/run/node.tsx | 53 +- .../components/workflow/run/tracing-panel.tsx | 264 ++++- web/app/components/workflow/store.ts | 4 + web/app/components/workflow/types.ts | 3 +- web/app/components/workflow/utils.ts | 311 ++++- web/i18n/en-US/workflow.ts | 17 +- web/i18n/zh-Hans/workflow.ts | 17 +- web/package.json | 2 +- web/service/base.ts | 18 +- web/themes/dark.css | 1 + web/themes/light.css | 1 + web/themes/tailwind-theme-var-define.ts | 1 + web/types/workflow.ts | 53 + 156 files changed, 11151 insertions(+), 5598 deletions(-) delete mode 100644 api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py delete mode 100644 api/core/app/apps/workflow/workflow_event_trigger_callback.py create mode 100644 api/core/app/apps/workflow_app_runner.py delete mode 100644 api/core/app/task_pipeline/workflow_iteration_cycle_manage.py create mode 100644 api/core/workflow/graph_engine/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/base_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_manager.py create mode 100644 api/core/workflow/graph_engine/entities/__init__.py create mode 100644 api/core/workflow/graph_engine/entities/event.py create mode 100644 api/core/workflow/graph_engine/entities/graph.py create mode 100644 api/core/workflow/graph_engine/entities/graph_init_params.py create mode 100644 api/core/workflow/graph_engine/entities/graph_runtime_state.py create mode 100644 api/core/workflow/graph_engine/entities/next_graph_node.py create mode 100644 api/core/workflow/graph_engine/entities/run_condition.py create mode 100644 api/core/workflow/graph_engine/entities/runtime_route_state.py create mode 100644 api/core/workflow/graph_engine/graph_engine.py create mode 100644 api/core/workflow/nodes/answer/answer_stream_generate_router.py create mode 100644 api/core/workflow/nodes/answer/answer_stream_processor.py create mode 100644 api/core/workflow/nodes/answer/base_stream_processor.py create mode 100644 api/core/workflow/nodes/end/end_stream_generate_router.py create mode 100644 api/core/workflow/nodes/end/end_stream_processor.py create mode 100644 api/core/workflow/nodes/event.py create mode 100644 api/core/workflow/nodes/iteration/iteration_start_node.py create mode 100644 api/core/workflow/nodes/node_mapping.py create mode 100644 api/core/workflow/utils/condition/__init__.py create mode 100644 api/core/workflow/utils/condition/entities.py create mode 100644 api/core/workflow/utils/condition/processor.py create mode 100644 api/core/workflow/workflow_entry.py create mode 100644 api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py create mode 100644 web/app/components/workflow/limit-tips.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/next-step/container.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/next-step/operator.tsx create mode 100644 web/app/components/workflow/nodes/iteration-start/constants.ts create mode 100644 web/app/components/workflow/nodes/iteration-start/default.ts create mode 100644 web/app/components/workflow/nodes/iteration-start/index.tsx create mode 100644 web/app/components/workflow/nodes/iteration-start/types.ts delete mode 100644 web/app/components/workflow/nodes/iteration/insert-block.tsx diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 2d540ca584a927..e03dfeb27ca3ed 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.7.3", + default="0.8.0", ) COMMIT_SHA: str = Field( diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index e7c9ebe09749de..638cc07461d714 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,12 +4,10 @@ import threading import uuid from collections.abc import Generator -from typing import Literal, Union, overload +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -20,20 +18,15 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -60,13 +53,14 @@ def generate( ) -> dict: ... def generate( - self, app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ): + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -154,7 +148,8 @@ def single_iteration_generate(self, app_model: App, node_id: str, user: Account, args: dict, - stream: bool = True): + stream: bool = True) \ + -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -171,16 +166,6 @@ def single_iteration_generate(self, app_model: App, if args.get('inputs') is None: raise ValueError('inputs is required') - extras = { - "auto_generate_conversation_name": False - } - - # get conversation - conversation = None - conversation_id = args.get('conversation_id') - if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) - # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( app_model=app_model, @@ -191,14 +176,16 @@ def single_iteration_generate(self, app_model: App, application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - conversation_id=conversation.id if conversation else None, + conversation_id=None, inputs={}, query='', files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={ + "auto_generate_conversation_name": False + }, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args['inputs'] @@ -211,17 +198,28 @@ def single_iteration_generate(self, app_model: App, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=None, stream=stream ) def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation | None = None, - stream: bool = True): + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True) \ + -> dict[str, Any] | Generator[str, Any, None]: + """ + Generate App response. + + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ is_first_conversation = False if not conversation: is_first_conversation = True @@ -236,7 +234,7 @@ def _generate(self, *, # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - # db.session.refresh(conversation) + db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -248,67 +246,12 @@ def _generate(self, *, message_id=message.id ) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - # Create conversation variables if they don't exist. - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] - - session.commit() - - # Increment dialogue count. - conversation.dialogue_count += 1 - - conversation_id = conversation.id - conversation_dialogue_count = conversation.dialogue_count - db.session.commit() - db.session.refresh(conversation) - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - contexts.workflow_variable_pool.set(variable_pool) - # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, + 'conversation_id': conversation.id, 'message_id': message.id, 'context': contextvars.copy_context(), }) @@ -334,6 +277,7 @@ def _generate(self, *, def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, + conversation_id: str, message_id: str, context: contextvars.Context) -> None: """ @@ -349,28 +293,19 @@ def _generate_worker(self, flask_app: Flask, var.set(val) with flask_app.app_context(): try: - runner = AdvancedChatAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - # get message - message = self._get_message(message_id) - - # chatbot app - runner = AdvancedChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 5dc03979cf3b4b..4da3d093d276c5 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,49 +1,67 @@ import logging import os -import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueStopEvent, + QueueTextChunkEvent, +) from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from models import App, Message, Workflow +from models.model import App, Conversation, EndUser, Message +from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) -class AdvancedChatAppRunner(AppRunner): +class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ AdvancedChat Application Runner """ - def run( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message, + def __init__( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message ) -> None: """ - Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :param conversation: conversation :param message: message + """ + super().__init__(queue_manager) + + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.message = message + + def run(self) -> None: + """ + Run application :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) app_record = db.session.query(App).filter(App.id == app_config.app_id).first() @@ -54,101 +72,133 @@ def run( if not workflow: raise ValueError('Workflow not initialized') - inputs = application_generate_entity.inputs - query = application_generate_entity.query + user_id = None + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id - # moderation - if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id, - ): - return + workflow_callbacks: list[WorkflowCallback] = [] + if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + workflow_callbacks.append(WorkflowLoggingCallback()) - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - ): - return + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files - db.session.close() + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id + ): + return - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity + ): + return - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): - workflow_callbacks.append(WorkflowLoggingCallback()) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] - # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, - ) + session.commit() - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') + # Increment dialogue count. + self.conversation.dialogue_count += 1 - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') + conversation_dialogue_count = self.conversation.dialogue_count + db.session.commit() - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] + # Create a variable pool. + system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + } - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) + + db.session.close() + + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() + generator = workflow_entry.run( + callbacks=workflow_callbacks, ) - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) def handle_input_moderation( - self, - queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str, + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str ) -> bool: """ Handle input moderation - :param queue_manager: application queue manager :param app_record: app record :param app_generate_entity: application generate entity :param inputs: inputs @@ -167,30 +217,23 @@ def handle_input_moderation( message_id=message_id, ) except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=str(e), - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION ) return True return False - def handle_annotation_reply( - self, - app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity, - ) -> bool: + def handle_annotation_reply(self, app_record: App, + message: Message, + query: str, + app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: """ Handle annotation reply :param app_record: app record :param message: message :param query: query - :param queue_manager: application queue manager :param app_generate_entity: application generate entity """ # annotation reply @@ -203,37 +246,32 @@ def handle_annotation_reply( ) if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER + self._publish_event( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) ) - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=annotation_reply.content, - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _stream_output( - self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy - ) -> None: + def _complete_with_stream_output(self, + text: str, + stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output - :param queue_manager: application queue manager :param text: text - :param stream: stream :return: """ - if stream: - index = 0 - for token in text: - queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) + self._publish_event( + QueueTextChunkEvent( + text=text + ) + ) - queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) + self._publish_event( + QueueStopEvent(stopped_by=stopped_by) + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2b3596ded261e5..fb013cd1b1f7ed 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,9 +2,8 @@ import logging import time from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union -import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -22,6 +21,9 @@ QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, @@ -31,34 +33,28 @@ QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, - ChatflowStreamGenerateRoute, ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, + WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.file.file_obj import FileVar -from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowRunStatus, ) @@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: AdvancedChatTaskState + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] - # Deprecated _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, queue_manager: AppQueueManager, conversation: Conversation, @@ -106,7 +101,6 @@ def __init__( self._workflow = workflow self._conversation = conversation self._message = message - # Deprecated self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, @@ -114,12 +108,8 @@ def __init__( SystemVariableKey.USER_ID: user_id, } - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) + self._task_state = WorkflowTaskState() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None def process(self): @@ -140,6 +130,7 @@ def process(self): generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) + if self._stream: return self._to_stream_response(generator) else: @@ -199,17 +190,18 @@ def _listenAudioMsg(self, publisher, task_id: str): def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -220,9 +212,9 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -240,34 +232,34 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ - for message in self._queue_manager.listen(): - if (message.event - and getattr(message.event, 'metadata', None) - and message.event.metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - elif (hasattr(message.event, 'execution_metadata') - and message.event.execution_metadata - and message.event.execution_metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - event = message.event - - if isinstance(event, QueueErrorEvent): + # init fake graph runtime state + graph_runtime_state = None + workflow_run = None + + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.workflow_run_id = workflow_run.id db.session.commit() @@ -279,133 +271,242 @@ def _process_stream_response( workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: - self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] - # reset current route position to 0 - self._task_state.current_stream_generate_state.current_route_position = 0 + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() + response = self._workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) - yield self._workflow_node_start_to_stream_response( + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) + + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - # stream outputs when node finished - generator = self._generate_stream_outputs_when_node_finished() - if generator: - yield from generator + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event ) - if workflow_run: - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if event.outputs else None, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + elif isinstance(event, QueueWorkflowFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + yield self._error_to_stream_response(self._handle_error(err_event, self._message)) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message() + self._save_message(graph_runtime_state=graph_runtime_state) yield self._message_end_to_stream_response() + break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) + + self._refetch_message() + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) + + self._refetch_message() + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(delta_text) if should_direct_answer: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text yield self._message_to_stream_response(delta_text, self._message.id) elif isinstance(event, QueueMessageReplaceEvent): + # published by moderation yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + elif isinstance(event, QueueAdvancedChatMessageEndEvent): + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message(graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() else: continue - if publisher: - publisher.publish(None) + + # publish None when task finished + if tts_publisher: + tts_publisher.publish(None) + if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self) -> None: + def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: """ Save message. :return: """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ if self._task_state.metadata else None - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) - + if graph_runtime_state and graph_runtime_state.llm_usage: + usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -432,7 +533,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras['metadata'] = self._task_state.metadata.copy() + + if 'annotation_reply' in extras['metadata']: + del extras['metadata']['annotation_reply'] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, @@ -440,323 +544,6 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: **extras ) - def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]: - """ - Get stream generate routes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - answer_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.ANSWER.value - ] - - # parse stream output node value selectors of answer nodes - stream_generate_routes = {} - for node_config in answer_node_configs: - # get generate route for stream output - answer_node_id = node_config['id'] - generate_route = AnswerNode.extract_generate_route_selectors(node_config) - start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute( - answer_node_id=answer_node_id, - generate_route=generate_route - ) - - return stream_generate_routes - - def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get answer start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - # check if it's the first node in the iteration - target_node = next((node for node in nodes if node.get('id') == target_node_id), None) - if not target_node: - return [] - - node_iteration_id = target_node.get('data', {}).get('iteration_id') - # get iteration start node id - for node in nodes: - if node.get('id') == node_iteration_id: - if node.get('data', {}).get('start_node_id') == target_node_id: - return [target_node_id] - - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, - NodeType.LOOP.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position: - ] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text) - if should_direct_answer: - continue - - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - break - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: - """ - Generate stream outputs. - :return: - """ - if not self._task_state.current_stream_generate_state: - return - - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position:] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - value = None - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - if not value_selector: - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - route_chunk_node_id = value_selector[0] - - if route_chunk_node_id == 'sys': - # system variable - value = contexts.workflow_variable_pool.get().get(value_selector) - if value: - value = value.text - elif route_chunk_node_id in self._iteration_nested_relations: - # it's a iteration variable - if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: - continue - iteration_state = self._iteration_state.current_iterations[route_chunk_node_id] - iterator = iteration_state.inputs - if not iterator: - continue - iterator_selector = iterator.get('iterator_selector', []) - if value_selector[1] == 'index': - value = iteration_state.current_index - elif value_selector[1] == 'item': - value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len( - iterator_selector - ) else None - else: - # check chunk node id is before current node id or equal to current node id - if route_chunk_node_id not in self._task_state.ran_node_execution_infos: - break - - latest_node_execution_info = self._task_state.latest_node_execution_info - - # get route chunk node execution info - route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] - if (route_chunk_node_execution_info.node_type == NodeType.LLM - and latest_node_execution_info.node_type == NodeType.LLM): - # only LLM support chunk stream output - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - # get route chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id - ).first() - - outputs = route_chunk_node_execution.outputs_dict - - # get value from outputs - value = None - for key in value_selector[1:]: - if not value: - value = outputs.get(key) if outputs else None - else: - value = value.get(key) - - if value is not None: - text = '' - if isinstance(value, str | int | float): - text = str(value) - elif isinstance(value, FileVar): - # convert file to markdown - text = value.to_markdown() - elif isinstance(value, dict): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - if file_vars: - file_var = file_vars[0] - try: - file_var_obj = FileVar(**file_var) - - # convert file to markdown - text = file_var_obj.to_markdown() - except Exception as e: - logger.error(f'Error creating file var: {e}') - - if not text: - # other types - text = json.dumps(value, ensure_ascii=False) - elif isinstance(value, list): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - for file_var in file_vars: - try: - file_var_obj = FileVar(**file_var) - except Exception as e: - logger.error(f'Error creating file var: {e}') - continue - - # convert file to markdown - text = file_var_obj.to_markdown() + ' ' - - text = text.strip() - - if not text and value: - # other types - text = json.dumps(value, ensure_ascii=False) - - if text: - self._task_state.answer += text - yield self._message_to_stream_response(text, self._message.id) - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return True - - if 'node_id' not in event.metadata: - return True - - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - route_chunk = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position] - - if route_chunk.type != 'var': - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - - # check chunk node id is before current node id or equal to current node id - if value_selector != stream_output_value_selector: - return False - - return True - def _handle_output_moderation_chunk(self, text: str) -> bool: """ Handle output moderation chunk. @@ -782,3 +569,12 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: self._output_moderation_handler.append_new_token(text) return False + + def _refetch_message(self) -> None: + """ + Refetch message. + :return: + """ + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if message: + self._message = message diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py deleted file mode 100644 index 8d43155a0886bf..00000000000000 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager._publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager._publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self._queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 1165314a7f2bd8..a196d36be5f1f7 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert(cls, response: Union[ AppBlockingResponse, Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom): + ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 2c5feaaaafb153..60216959a85b21 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time -from collections.abc import Generator -from typing import TYPE_CHECKING, Optional, Union +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -347,7 +347,7 @@ def moderation_for_inputs( self, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, ) -> tuple[bool, dict, str]: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 26bb6c0f4f70b8..4347e5277bb320 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import threading import uuid from collections.abc import Generator -from typing import Literal, Union, overload +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -40,6 +40,8 @@ def generate( args: dict, invoke_from: InvokeFrom, stream: Literal[True] = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> Generator[str, None, None]: ... @overload @@ -50,16 +52,20 @@ def generate( args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> dict: ... def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ): """ Generate App response. @@ -71,6 +77,7 @@ def generate( :param invoke_from: invoke from source :param stream: is stream :param call_depth: call depth + :param workflow_thread_pool_id: workflow thread pool id """ inputs = args['inputs'] @@ -118,16 +125,19 @@ def generate( application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, + workflow_thread_pool_id=workflow_thread_pool_id ) def _generate( - self, app_model: App, + self, *, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + workflow_thread_pool_id: Optional[str] = None + ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -137,6 +147,7 @@ def _generate( :param application_generate_entity: application generate entity :param invoke_from: invoke from source :param stream: is stream + :param workflow_thread_pool_id: workflow thread pool id """ # init queue manager queue_manager = WorkflowAppQueueManager( @@ -148,10 +159,11 @@ def _generate( # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), + 'workflow_thread_pool_id': workflow_thread_pool_id }) worker_thread.start() @@ -175,7 +187,7 @@ def single_iteration_generate(self, app_model: App, node_id: str, user: Account, args: dict, - stream: bool = True): + stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -192,10 +204,6 @@ def single_iteration_generate(self, app_model: App, if args.get('inputs') is None: raise ValueError('inputs is required') - extras = { - "auto_generate_conversation_name": False - } - # convert to app config app_config = WorkflowAppConfigManager.get_app_config( app_model=app_model, @@ -211,7 +219,9 @@ def single_iteration_generate(self, app_model: App, user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={ + "auto_generate_conversation_name": False + }, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args['inputs'] @@ -231,12 +241,14 @@ def single_iteration_generate(self, app_model: App, def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, - context: contextvars.Context) -> None: + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None) -> None: """ Generate worker in a new thread. :param flask_app: Flask app :param application_generate_entity: application generate entity :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id :return: """ for var, val in context.items(): @@ -244,22 +256,13 @@ def _generate_worker(self, flask_app: Flask, with flask_app.app_context(): try: # workflow app - runner = WorkflowAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id + ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: @@ -271,14 +274,14 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index e388d0184b83b1..9d48db7546d092 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,46 +4,61 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig -from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import WorkflowType logger = logging.getLogger(__name__) -class WorkflowAppRunner: +class WorkflowAppRunner(WorkflowBasedAppRunner): """ Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id + + def run(self) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: @@ -53,80 +68,64 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_mana if not workflow: raise ValueError('Workflow not initialized') - inputs = application_generate_entity.inputs - files = application_generate_entity.files - db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - + workflow_callbacks: list[WorkflowCallback] = [] if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) - # Create a variable pool. - system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: - # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, - variable_pool=variable_pool, - ) + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') + # Create a variable pool. + system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + } - if not app_record.workflow_id: - raise ValueError('Workflow not initialized') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() + generator = workflow_entry.run( + callbacks=workflow_callbacks ) - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index de8542d7b9f859..00b3b9f57e9024 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -15,10 +16,12 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, - QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, @@ -32,19 +35,16 @@ MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, - TextReplaceStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowStreamGenerateNodes, + WorkflowStartStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -52,8 +52,8 @@ Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, + WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -96,11 +95,7 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, SystemVariableKey.USER_ID: user_id } - self._task_state = WorkflowTaskState( - iteration_nested_node_ids=[] - ) - self._stream_generate_nodes = self._get_stream_generate_nodes() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) + self._task_state = WorkflowTaskState() def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -129,23 +124,20 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status, - outputs=workflow_run.outputs_dict, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at) ) ) @@ -161,9 +153,13 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) To stream response. :return: """ + workflow_run_id = None for stream_response in generator: + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id + yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run_id, + workflow_run_id=workflow_run_id, stream_response=stream_response ) @@ -178,17 +174,18 @@ def _listenAudioMsg(self, publisher, task_id: str): def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -198,9 +195,9 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -218,69 +215,159 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: - self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) + + response = self._workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - yield self._workflow_node_finish_to_stream_response( + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) + + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, trace_manager=trace_manager + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, ) # save workflow app log @@ -295,22 +382,17 @@ def _process_stream_response( if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response(delta_text) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._text_replace_to_stream_response(event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() else: continue - if publisher: - publisher.publish(None) + if tts_publisher: + tts_publisher.publish(None) def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: @@ -329,15 +411,15 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: # not save log for debugging return - workflow_app_log = WorkflowAppLog( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, - created_from=created_from.value, - created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), - created_by=self._user.id, - ) + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by = self._user.id + db.session.add(workflow_app_log) db.session.commit() db.session.close() @@ -354,180 +436,3 @@ def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: ) return response - - def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: - """ - Text replace to stream response. - :param text: text - :return: - """ - return TextReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - text=TextReplaceStreamResponse.Data(text=text) - ) - - def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: - """ - Get stream generate nodes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - end_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.END.value - ] - - # parse stream output node value selectors of end nodes - stream_generate_routes = {} - for node_config in end_node_configs: - # get generate route for stream output - end_node_id = node_config['id'] - generate_nodes = EndNode.extract_generate_nodes(graph, node_config) - start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( - end_node_id=end_node_id, - stream_node_ids=generate_nodes - ) - - return stream_generate_routes - - def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get end start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids - - for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): - if node_id not in stream_node_ids: - continue - - node_execution_info = self._task_state.ran_node_execution_infos[node_id] - - # get chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() - - if not route_chunk_node_execution: - continue - - outputs = route_chunk_node_execution.outputs_dict - - if not outputs: - continue - - # get value from outputs - text = outputs.get('text') - - if text: - self._task_state.answer += text - yield self._text_chunk_to_stream_response(text) - - db.session.close() - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return False - - if 'node_id' not in event.metadata: - return False - - node_id = event.metadata.get('node_id') - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - return True - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py deleted file mode 100644 index 4472a7e9b5a85c..00000000000000 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager.publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager.publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - pass diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py new file mode 100644 index 00000000000000..17097268876a67 --- /dev/null +++ b/api/core/app/apps/workflow_app_runner.py @@ -0,0 +1,379 @@ +from collections.abc import Mapping +from typing import Any, Optional, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowBasedAppRunner(AppRunner): + def __init__(self, queue_manager: AppQueueManager): + self.queue_manager = queue_manager + + def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: + """ + Init graph + """ + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + # init graph + graph = Graph.init( + graph_config=graph_config + ) + + if not graph: + raise ValueError('graph not found in workflow') + + return graph + + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + # fetch workflow graph + graph_config = workflow.graph_dict + if not graph_config: + raise ValueError('workflow graph not found') + + graph_config = cast(dict[str, Any], graph_config) + + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + + # filter nodes only in iteration + node_configs = [ + node for node in graph_config.get('nodes', []) + if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + ] + + graph_config['nodes'] = node_configs + + node_ids = [node.get('id') for node in node_configs] + + # filter edges only in iteration + edge_configs = [ + edge for edge in graph_config.get('edges', []) + if (edge.get('source') is None or edge.get('source') in node_ids) + and (edge.get('target') is None or edge.get('target') in node_ids) + ] + + graph_config['edges'] = edge_configs + + # init graph + graph = Graph.init( + graph_config=graph_config, + root_node_id=node_id + ) + + if not graph: + raise ValueError('graph not found in workflow') + + # fetch node config from node id + iteration_node_config = None + for node in node_configs: + if node.get('id') == node_id: + iteration_node_config = node + break + + if not iteration_node_config: + raise ValueError('iteration node id not found in workflow graph') + + # Get node class + node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=iteration_node_config + ) + except NotImplementedError: + variable_mapping = {} + + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get('data', {})) + ) + + return graph, variable_pool + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent( + graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state + ) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event( + QueueWorkflowSucceededEvent(outputs=event.outputs) + ) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event( + QueueWorkflowFailedEvent(error=event.error) + ) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result else {}, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result + and event.route_node_state.node_run_result.error + else "Unknown error", + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content, + from_variable_selector=event.from_variable_selector, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunSucceededEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + error=event.error + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish( + event, + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 2e6431d6d05f4d..4e8f3644b1bcb2 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -1,10 +1,24 @@ from typing import Optional -from core.app.entities.queue_entities import AppQueueEvent from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: self.current_node_id = None - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self.print_text("\n[on_workflow_run_started]", color='pink') - - def on_workflow_run_succeeded(self) -> None: + def on_event( + self, + event: GraphEngineEvent + ) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[GraphRunStartedEvent]", color='pink') + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[GraphRunSucceededEvent]", color='green') + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started( + event=event + ) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded( + event=event + ) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed( + event=event + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk( + event=event + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self.on_workflow_parallel_started( + event=event + ) + elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): + self.on_workflow_parallel_completed( + event=event + ) + elif isinstance(event, IterationRunStartedEvent): + self.on_workflow_iteration_started( + event=event + ) + elif isinstance(event, IterationRunNextEvent): + self.on_workflow_iteration_next( + event=event + ) + elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): + self.on_workflow_iteration_completed( + event=event + ) + else: + self.print_text(f"\n[{event.__class__.__name__}]", color='blue') + + def on_workflow_node_execute_started( + self, + event: NodeRunStartedEvent + ) -> None: """ - Workflow run succeeded + Workflow node execute started """ - self.print_text("\n[on_workflow_run_succeeded]", color='green') + self.print_text("\n[NodeRunStartedEvent]", color='yellow') + self.print_text(f"Node ID: {event.node_id}", color='yellow') + self.print_text(f"Node Title: {event.node_data.title}", color='yellow') + self.print_text(f"Type: {event.node_type.value}", color='yellow') - def on_workflow_run_failed(self, error: str) -> None: + def on_workflow_node_execute_succeeded( + self, + event: NodeRunSucceededEvent + ) -> None: """ - Workflow run failed + Workflow node execute succeeded """ - self.print_text("\n[on_workflow_run_failed]", color='red') - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunSucceededEvent]", color='green') + self.print_text(f"Node ID: {event.node_id}", color='green') + self.print_text(f"Node Title: {event.node_data.title}", color='green') + self.print_text(f"Type: {event.node_type.value}", color='green') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color='green') + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color='green') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color='green') + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", + color='green') + + def on_workflow_node_execute_failed( + self, + event: NodeRunFailedEvent + ) -> None: """ - Workflow node execute started + Workflow node execute failed """ - self.print_text("\n[on_workflow_node_execute_started]", color='yellow') - self.print_text(f"Node ID: {node_id}", color='yellow') - self.print_text(f"Type: {node_type.value}", color='yellow') - self.print_text(f"Index: {node_run_index}", color='yellow') - if predecessor_node_id: - self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunFailedEvent]", color='red') + self.print_text(f"Node ID: {event.node_id}", color='red') + self.print_text(f"Node Title: {event.node_data.title}", color='red') + self.print_text(f"Type: {event.node_type.value}", color='red') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color='red') + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color='red') + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color='red') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color='red') + + def on_node_text_chunk( + self, + event: NodeRunStreamChunkEvent + ) -> None: """ - Workflow node execute succeeded + Publish text chunk """ - self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') - self.print_text(f"Node ID: {node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') - self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", - color='green') - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id + self.print_text('\n[NodeRunStreamChunkEvent]') + self.print_text(f"Node ID: {route_node_state.node_id}") + + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") + + self.print_text(event.chunk_content, color="pink", end="") + + def on_workflow_parallel_started( + self, + event: ParallelBranchRunStartedEvent + ) -> None: """ - Workflow node execute failed + Publish parallel started """ - self.print_text("\n[on_workflow_node_execute_failed]", color='red') - self.print_text(f"Node ID: {node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') - self.print_text(f"Error: {error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') + self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') + self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + def on_workflow_parallel_completed( + self, + event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + ) -> None: """ - Publish text chunk + Publish parallel completed """ - if not self.current_node_id or self.current_node_id != node_id: - self.current_node_id = node_id - self.print_text('\n[on_node_text_chunk]') - self.print_text(f"Node ID: {node_id}") - self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") + if isinstance(event, ParallelBranchRunSucceededEvent): + color = 'blue' + elif isinstance(event, ParallelBranchRunFailedEvent): + color = 'red' - self.print_text(text, color="pink", end="") + self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) + self.print_text(f"Parallel ID: {event.parallel_id}", color=color) + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: + if isinstance(event, ParallelBranchRunFailedEvent): + self.print_text(f"Error: {event.error}", color=color) + + def on_workflow_iteration_started( + self, + event: IterationRunStartedEvent + ) -> None: """ Publish iteration started """ - self.print_text("\n[on_workflow_iteration_started]", color='blue') - self.print_text(f"Node ID: {node_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color='blue') + self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[dict]) -> None: + def on_workflow_iteration_next( + self, + event: IterationRunNextEvent + ) -> None: """ Publish iteration next """ - self.print_text("\n[on_workflow_iteration_next]", color='blue') + self.print_text("\n[IterationRunNextEvent]", color='blue') + self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') + self.print_text(f"Iteration Index: {event.index}", color='blue') - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: + def on_workflow_iteration_completed( + self, + event: IterationRunSucceededEvent | IterationRunFailedEvent + ) -> None: """ Publish iteration completed """ - self.print_text("\n[on_workflow_iteration_completed]", color='blue') - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self.print_text("\n[on_workflow_event]", color='blue') - self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') + self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') + self.print_text(f"Node ID: {event.iteration_id}", color='blue') def print_text( self, text: str, color: Optional[str] = None, end: str = "\n" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15348251f2de35..4c86b7eee13e3e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,3 +1,4 @@ +from datetime import datetime from enum import Enum from typing import Any, Optional @@ -5,7 +6,8 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class QueueEvent(str, Enum): @@ -31,6 +33,9 @@ class QueueEvent(str, Enum): ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" MESSAGE_FILE = "message_file" + PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" + PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" + PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" ERROR = "error" PING = "ping" STOP = "stop" @@ -38,7 +43,7 @@ class QueueEvent(str, Enum): class AppQueueEvent(BaseModel): """ - QueueEvent entity + QueueEvent abstract entity """ event: QueueEvent @@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel): class QueueLLMChunkEvent(AppQueueEvent): """ QueueLLMChunkEvent entity + Only for basic mode apps """ event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk @@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent): QueueIterationStartEvent entity """ event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: dict = None + inputs: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict] = None + metadata: Optional[dict[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): """ @@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent): event: QueueEvent = QueueEvent.ITERATION_NEXT index: int + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ - event:QueueEvent = QueueEvent.ITERATION_COMPLETED + event: QueueEvent = QueueEvent.ITERATION_COMPLETED + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - outputs: dict + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + error: Optional[str] = None + class QueueTextChunkEvent(AppQueueEvent): """ @@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - metadata: Optional[dict] = None + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAgentMessageEvent(AppQueueEvent): @@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAnnotationReplyEvent(AppQueueEvent): @@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): QueueWorkflowStartedEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): QueueWorkflowSucceededEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_STARTED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 predecessor_node_id: Optional[str] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime class QueueNodeSucceededEvent(AppQueueEvent): @@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_SUCCEEDED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData - - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None - execution_metadata: Optional[dict] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None @@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_FAILED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData - - inputs: Optional[dict] = None - outputs: Optional[dict] = None - process_data: Optional[dict] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None error: str @@ -274,10 +360,23 @@ class StopBy(Enum): event: QueueEvent = QueueEvent.STOP stopped_by: StopBy + def get_stop_reason(self) -> str: + """ + To stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', + QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', + QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', + QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + } + + return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + class QueueMessage(BaseModel): """ - QueueMessage entity + QueueMessage abstract entity """ task_id: str app_mode: str @@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage): WorkflowQueueMessage entity """ pass + + +class QueueParallelBranchRunStartedEvent(AppQueueEvent): + """ + QueueParallelBranchRunStartedEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunSucceededEvent(AppQueueEvent): + """ + QueueParallelBranchRunSucceededEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunFailedEvent(AppQueueEvent): + """ + QueueParallelBranchRunFailedEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7bc55989843305..7cab6ca4e0d4bd 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -3,40 +3,11 @@ from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.entities import GenerateRouteChunk from models.workflow import WorkflowNodeExecutionStatus -class WorkflowStreamGenerateNodes(BaseModel): - """ - WorkflowStreamGenerateNodes entity - """ - end_node_id: str - stream_node_ids: list[str] - - -class ChatflowStreamGenerateRoute(BaseModel): - """ - ChatflowStreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - -class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class TaskState(BaseModel): """ TaskState entity @@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState): """ answer: str = "" - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None - - iteration_nested_node_ids: list[str] = None - - -class AdvancedChatTaskState(WorkflowTaskState): - """ - AdvancedChatTaskState entity - """ - usage: LLMUsage - - current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None - class StreamEvent(Enum): """ @@ -97,6 +47,8 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + PARALLEL_BRANCH_STARTED = "parallel_branch_started" + PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -267,6 +219,11 @@ class Data(BaseModel): inputs: Optional[dict] = None created_at: int extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -286,7 +243,12 @@ def to_ignore_detail_dict(self): "predecessor_node_id": self.data.predecessor_node_id, "inputs": None, "created_at": self.data.created_at, - "extras": {} + "extras": {}, + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, } } @@ -316,6 +278,11 @@ class Data(BaseModel): created_at: int finished_at: int files: Optional[list[dict]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -342,9 +309,58 @@ def to_ignore_detail_dict(self): "execution_metadata": None, "created_at": self.data.created_at, "finished_at": self.data.finished_at, - "files": [] + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, } } + + +class ParallelBranchStartStreamResponse(StreamResponse): + """ + ParallelBranchStartStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED + workflow_run_id: str + data: Data + + +class ParallelBranchFinishedStreamResponse(StreamResponse): + """ + ParallelBranchFinishedStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + status: str + error: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED + workflow_run_id: str + data: Data class IterationNodeStartStreamResponse(StreamResponse): @@ -364,6 +380,8 @@ class Data(BaseModel): extras: dict = {} metadata: dict = {} inputs: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -387,6 +405,8 @@ class Data(BaseModel): created_at: int pre_iteration_output: Optional[Any] = None extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -408,8 +428,8 @@ class Data(BaseModel): title: str outputs: Optional[dict] = None created_at: int - extras: dict = None - inputs: dict = None + extras: Optional[dict] = None + inputs: Optional[dict] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float @@ -417,6 +437,8 @@ class Data(BaseModel): execution_metadata: Optional[dict] = None finished_at: int steps: int + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ - workflow_run_id: str + workflow_run_id: Optional[str] = None class AppBlockingResponse(BaseModel): @@ -562,25 +584,3 @@ class Data(BaseModel): workflow_run_id: str data: Data - - -class WorkflowIterationState(BaseModel): - """ - WorkflowIterationState entity - """ - - class Data(BaseModel): - """ - Data entity - """ - parent_iteration_id: Optional[str] = None - iteration_id: str - current_index: int - iteration_steps_boundary: list[int] = None - node_execution_id: str - started_at: float - inputs: dict = None - total_tokens: int = 0 - node_data: BaseNodeData - - current_iterations: dict[str, Data] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a3c1fb58245a11..2f74a180d1266c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -68,16 +68,18 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) if message: - message = db.session.query(Message).filter(Message.id == message.id).first() - err_desc = self._error_to_desc(err) - message.status = 'error' - message.error = err_desc + refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.commit() + if refetch_message: + err_desc = self._error_to_desc(err) + refetch_message.status = 'error' + refetch_message.error = err_desc + + db.session.commit() return err - def _error_to_desc(cls, e: Exception) -> str: + def _error_to_desc(self, e: Exception) -> str: """ Error to desc. :param e: exception diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 76c50809cf340c..8ff50dd1747c0a 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -8,7 +8,6 @@ AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, @@ -16,11 +15,11 @@ QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager @@ -36,7 +35,7 @@ class MessageCycleManage: AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] - _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + _task_state: Union[EasyUITaskState, WorkflowTaskState] def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: """ @@ -45,6 +44,9 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> :param query: query :return: thread """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) @@ -52,7 +54,7 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> if auto_generate_conversation_name and is_first_message: # start generate thread thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'conversation_id': conversation.id, 'query': query }) @@ -75,6 +77,9 @@ def _generate_conversation_name_worker(self, .first() ) + if not conversation: + return + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: @@ -121,34 +126,13 @@ def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> No if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata['retriever_resources'] = event.retriever_resources - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. :param event: event :return: """ - message_file: MessageFile = ( + message_file = ( db.session.query(MessageFile) .filter(MessageFile.id == event.message_file_id) .first() diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4935c43ac437e4..ed3225310a8ec4 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,33 +1,41 @@ import json import time from datetime import datetime, timezone -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, ) from core.app.entities.task_entities import ( - NodeExecutionInfo, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, + WorkflowTaskState, ) -from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -41,54 +49,56 @@ WorkflowRunStatus, WorkflowRunTriggeredFrom, ) -from services.workflow_service import WorkflowService - - -class WorkflowCycleManage(WorkflowIterationCycleManage): - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 + + +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _workflow_system_variables: dict[SystemVariableKey, Any] + + def _handle_workflow_run_start(self) -> WorkflowRun: + max_sequence = ( + db.session.query(db.func.max(WorkflowRun.sequence_number)) + .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) + .filter(WorkflowRun.app_id == self._workflow.app_id) + .scalar() + or 0 + ) new_sequence_number = max_sequence + 1 - inputs = {**user_inputs} - for key, value in (system_inputs or {}).items(): + inputs = {**self._application_generate_entity.inputs} + for key, value in (self._workflow_system_variables or {}).items(): if key.value == 'conversation': continue inputs[f'sys.{key.value}'] = value - inputs = WorkflowEngineManager.handle_special_values(inputs) + + inputs = WorkflowEntry.handle_special_values(inputs) + + triggered_from= ( + WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN + ) # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps(inputs), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id + workflow_run = WorkflowRun() + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING.value + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value ) + workflow_run.created_by = self._user.id db.session.add(workflow_run) db.session.commit() @@ -97,33 +107,37 @@ def _init_workflow_run(self, workflow: Workflow, return workflow_run - def _workflow_run_success( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_success( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, outputs: Optional[str] = None, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run success :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param outputs: outputs :param conversation_id: conversation id :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.outputs = outputs - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) - db.session.close() if trace_manager: trace_manager.add_trace_task( @@ -135,34 +149,58 @@ def _workflow_run_success( ) ) + db.session.close() + return workflow_run - def _workflow_run_failed( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_failed( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, error: str, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run failed :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param status: status :param error: error message :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = status.value workflow_run.error = error - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() + + running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value + ).all() + + for workflow_node_execution in running_workflow_node_executions: + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + db.session.commit() + db.session.refresh(workflow_run) db.session.close() @@ -178,39 +216,24 @@ def _workflow_run_failed( return workflow_run - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ + def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.add(workflow_node_execution) db.session.commit() @@ -219,33 +242,26 @@ def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, return workflow_node_execution - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: """ Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata + :param event: queue node succeeded event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -253,33 +269,24 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode return workflow_node_execution - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None - ) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: """ Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message + :param event: queue node failed event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.error = event.error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -287,8 +294,13 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE return workflow_node_execution - def _workflow_start_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + ################################################# + # to stream responses # + ################################################# + + def _workflow_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -302,13 +314,14 @@ def _workflow_start_to_stream_response(self, task_id: str, id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, - created_at=int(workflow_run.created_at.timestamp()) - ) + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + ), ) - def _workflow_finish_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. :param task_id: task id @@ -320,16 +333,16 @@ def _workflow_finish_to_stream_response(self, task_id: str, created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + 'id': created_by_account.id, + 'name': created_by_account.name, + 'email': created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, + 'id': created_by_end_user.id, + 'user': created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -348,14 +361,13 @@ def _workflow_finish_to_stream_response(self, task_id: str, created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + ), ) - def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution) \ - -> NodeStartStreamResponse: + def _workflow_node_start_to_stream_response( + self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeStartStreamResponse]: """ Workflow node start to stream response. :param event: queue node started event @@ -363,6 +375,9 @@ def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -374,8 +389,13 @@ def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()) - ) + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) # extras logic @@ -384,19 +404,27 @@ def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, response.data.extras['icon'] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, - provider_id=node_data.provider_id + provider_id=node_data.provider_id, ) return response - def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + def _workflow_node_finish_to_stream_response( + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. + :param event: queue node succeeded or failed event :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -416,181 +444,155 @@ def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_e execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) - - def _handle_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs=self._workflow_system_variables + + def _workflow_parallel_branch_start_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: + """ + Workflow parallel branch start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run started event + :return: + """ + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + created_at=int(time.time()), + ) ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id + + def _workflow_parallel_branch_finished_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent + ) -> ParallelBranchFinishedStreamResponse: + """ + Workflow parallel branch finished to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run succeeded or failed event + :return: + """ + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchFinishedStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), + ) ) - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() + def _workflow_iteration_start_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + """ + Workflow iteration start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration start event + :return: + """ + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) ) - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - - execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None - - if self._iteration_state and self._iteration_state.current_iterations: - if not execution_metadata: - execution_metadata = {} - current_iteration_data = None - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if data.parent_iteration_id == None: - current_iteration_data = data - break - - if current_iteration_data: - execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id - execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index - - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=execution_metadata + def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + """ + Workflow iteration next to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration next event + :return: + """ + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) + ) - if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if self._iteration_state: - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error, - inputs=event.inputs, - process_data=event.process_data, + def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + """ + Workflow iteration completed to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration completed event + :return: + """ + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, outputs=event.outputs, - execution_metadata=execution_metadata + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), + total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) - - db.session.close() - - return workflow_node_execution - - def _handle_workflow_finished( - self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None - ) -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - - if conversation_id is None: - conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - latest_node_execution_info = self._task_state.latest_node_execution_info - if latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() - if (workflow_node_execution - and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=latest_node_execution_info.start_at, - error='Workflow stopped.' - ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run + ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: """ @@ -647,3 +649,40 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: return value.to_dict() return None + + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Refetch workflow run + :param workflow_run_id: workflow run id + :return: + """ + workflow_run = db.session.query(WorkflowRun).filter( + WorkflowRun.id == workflow_run_id).first() + + if not workflow_run: + raise Exception(f'Workflow run not found: {workflow_run_id}') + + return workflow_run + + def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + """ + Refetch workflow node execution + :param node_execution_id: workflow node execution id + :return: + """ + workflow_node_execution = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, + WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, + WorkflowNodeExecution.workflow_id == self._workflow.id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.node_execution_id == node_execution_id, + ) + .first() + ) + + if not workflow_node_execution: + raise Exception(f'Workflow node execution not found: {node_execution_id}') + + return workflow_node_execution \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index bd98c82720f039..e69de29bb2d1d6 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -1,16 +0,0 @@ -from typing import Any, Union - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariableKey -from models.account import Account -from models.model import EndUser -from models.workflow import Workflow - - -class WorkflowCycleStateManager: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py deleted file mode 100644 index aff187071417c7..00000000000000 --- a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import time -from collections.abc import Generator -from datetime import datetime, timezone -from typing import Optional, Union - -from core.app.entities.queue_entities import ( - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, -) -from core.app.entities.task_entities import ( - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - NodeExecutionInfo, - WorkflowIterationState, -) -from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.workflow_engine_manager import WorkflowEngineManager -from extensions.ext_database import db -from models.workflow import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, -) - - -class WorkflowIterationCycleManage(WorkflowCycleStateManager): - _iteration_state: WorkflowIterationState = None - - def _init_iteration_state(self) -> WorkflowIterationState: - if not self._iteration_state: - self._iteration_state = WorkflowIterationState( - current_iterations={} - ) - - def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ - -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: - """ - Handle iteration to stream response - :param task_id: task id - :param event: iteration event - :return: - """ - if isinstance(event, QueueIterationStartEvent): - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs, - metadata=event.metadata - ) - ) - elif isinstance(event, QueueIterationNextEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={} - ) - ) - elif isinstance(event, QueueIterationCompletedEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - error=None, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) - - def _init_iteration_execution_from_workflow_run(self, - workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - inputs=json.dumps(inputs) if inputs else None, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - execution_metadata=json.dumps({ - 'started_run_index': node_run_index + 1, - 'current_index': 0, - 'steps_boundary': [], - }), - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: - if isinstance(event, QueueIterationStartEvent): - return self._handle_iteration_started(event) - elif isinstance(event, QueueIterationNextEvent): - return self._handle_iteration_next(event) - elif isinstance(event, QueueIterationCompletedEvent): - return self._handle_iteration_completed(event) - - def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: - self._init_iteration_state() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_iteration_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=NodeType.ITERATION, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( - parent_iteration_id=None, - iteration_id=event.node_id, - current_index=0, - iteration_steps_boundary=[], - node_execution_id=workflow_node_execution.id, - started_at=time.perf_counter(), - inputs=event.inputs, - total_tokens=0, - node_data=event.node_data - ) - - db.session.close() - - return workflow_node_execution - - def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: - if event.node_id not in self._iteration_state.current_iterations: - return - current_iteration = self._iteration_state.current_iterations[event.node_id] - current_iteration.current_index = event.index - current_iteration.iteration_steps_boundary.append(event.node_run_index) - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['current_index'] = event.index - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - db.session.close() - - def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): - if event.node_id not in self._iteration_state.current_iterations: - return - - current_iteration = self._iteration_state.current_iterations[event.node_id] - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - # remove current iteration - self._iteration_state.current_iterations.pop(event.node_id, None) - - # set latest node execution info - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.latest_node_execution_info = latest_node_execution_info - - db.session.close() - - def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: - """ - Handle iteration exception - """ - if not self._iteration_state or not self._iteration_state.current_iterations: - return - - for node_id, current_iteration in self._iteration_state.current_iterations.items(): - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - db.session.commit() - db.session.close() - - yield IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=node_id, - node_id=node_id, - node_type=NodeType.ITERATION.value, - title=current_iteration.node_data.title, - outputs={}, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a0573..59a4c103a20702 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -63,6 +63,39 @@ def empty_usage(cls): latency=0.0 ) + def plus(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Add two LLMUsage instances together. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + if self.total_tokens == 0: + return other + else: + return LLMUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + prompt_unit_price=other.prompt_unit_price, + prompt_price_unit=other.prompt_price_unit, + prompt_price=self.prompt_price + other.prompt_price, + completion_tokens=self.completion_tokens + other.completion_tokens, + completion_unit_price=other.completion_unit_price, + completion_price_unit=other.completion_price_unit, + completion_price=self.completion_price + other.completion_price, + total_tokens=self.total_tokens + other.total_tokens, + total_price=self.total_price + other.total_price, + currency=other.currency, + latency=self.latency + other.latency + ) + + def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Overload the + operator to add two LLMUsage instances. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + return self.plus(other) class LLMResult(BaseModel): """ diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 9a4d8db4e2f39d..69e28770c3415f 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -34,13 +34,13 @@ class OutputModeration(BaseModel): final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def should_direct_output(self): + def should_direct_output(self) -> bool: return self.final_output is not None - def get_final_output(self): - return self.final_output + def get_final_output(self) -> str: + return self.final_output or "" - def append_new_token(self, token: str): + def append_new_token(self, token: str) -> None: self.buffer += token if not self.thread: diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 12e498e76d8cd5..15e915628efbfa 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,7 +1,7 @@ import json import logging from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.file.file_obj import FileTransferMethod, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType @@ -18,6 +18,7 @@ class WorkflowTool(Tool): version: str workflow_entities: dict[str, Any] workflow_call_depth: int + thread_pool_id: Optional[str] = None label: str @@ -57,6 +58,7 @@ def _invoke( invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, + workflow_thread_pool_id=self.thread_pool_id ) data = result.get('data', {}) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0e15151aa49cba..6c0e906628ab82 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -128,6 +128,7 @@ def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, + thread_pool_id: Optional[str] = None ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. @@ -141,6 +142,7 @@ def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4a0188af49d6ff..4778d79ed9e5c9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -25,7 +25,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -249,7 +248,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the workflow tool runtime """ diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 564b9d3e14c15e..23e7c0c2438367 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -7,6 +7,7 @@ logger = logging.getLogger(__name__) + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 6db8adf4c21d72..9015eea85ce9db 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,116 +1,15 @@ from abc import ABC, abstractmethod -from typing import Any, Optional -from core.app.entities.queue_entities import AppQueueEvent -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self) -> None: + def on_event( + self, + event: GraphEngineEvent + ) -> None: """ - Workflow run started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - raise NotImplementedError - - @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any], - ) -> None: - """ - Publish iteration next - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - raise NotImplementedError - - @abstractmethod - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event + Published event """ raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index 6bf0c11c7d723f..e7e6710cbdf8a0 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel): desc: Optional[str] = None class BaseIterationNodeData(BaseNodeData): - start_node_id: str + start_node_id: Optional[str] = None class BaseIterationState(BaseModel): iteration_node_id: str diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 025453567bfc1b..5e2a5cb4663e1b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMUsage from models import WorkflowNodeExecutionStatus @@ -28,6 +28,7 @@ class NodeType(Enum): VARIABLE_ASSIGNER = 'variable-assigner' LOOP = 'loop' ITERATION = 'iteration' + ITERATION_START = 'iteration-start' # fake start node for iteration PARAMETER_EXTRACTOR = 'parameter-extractor' CONVERSATION_VARIABLE_ASSIGNER = 'assigner' @@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum): TOOL_INFO = 'tool_info' ITERATION_ID = 'iteration_id' ITERATION_INDEX = 'iteration_index' + PARALLEL_ID = 'parallel_id' + PARALLEL_START_NODE_ID = 'parallel_start_node_id' + PARENT_PARALLEL_ID = 'parent_parallel_id' + PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' class NodeRunResult(BaseModel): @@ -65,11 +70,32 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs + inputs: Optional[dict[str, Any]] = None # node inputs + process_data: Optional[dict[str, Any]] = None # process data + outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + + +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 27d0b672f66bf3..48a20d25ae422c 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -2,6 +2,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Union +from pydantic import BaseModel, Field, model_validator from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory @@ -16,43 +17,52 @@ CONVERSATION_VARIABLE_NODE_ID = "conversation" -class VariablePool: - def __init__( - self, - system_variables: Mapping[SystemVariableKey, Any], - user_inputs: Mapping[str, Any], - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable] | None = None, - ) -> None: - # system variables - # for example: - # { - # 'query': 'abc', - # 'files': [] - # } - - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) - - # TODO: This user inputs is not used for pool. - self.user_inputs = user_inputs +class VariablePool(BaseModel): + # Variable dictionary is a dictionary for looking up variables by their selector. + # The first element of the selector is the node id, it's the first-level key in the dictionary. + # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the + # elements of the selector except the first one. + variable_dictionary: dict[str, dict[int, Segment]] = Field( + description='Variables mapping', + default=defaultdict(dict) + ) + # TODO: This user inputs is not used for pool. + user_inputs: Mapping[str, Any] = Field( + description='User inputs', + ) + + system_variables: Mapping[SystemVariableKey, Any] = Field( + description='System variables', + ) + + environment_variables: Sequence[Variable] = Field( + description="Environment variables.", + default_factory=list + ) + + conversation_variables: Sequence[Variable] | None = None + + @model_validator(mode="after") + def val_model_after(self): + """ + Append system variables + :return: + """ # Add system variables to the variable pool - self.system_variables = system_variables - for key, value in system_variables.items(): + for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables: + for var in self.environment_variables or []: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) # Add conversation variables to the variable pool - for var in conversation_variables or []: + for var in self.conversation_variables or []: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + return self + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. @@ -79,7 +89,7 @@ def add(self, selector: Sequence[str], value: Any, /) -> None: v = factory.build_segment(value) hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]][hash_key] = v + self.variable_dictionary[selector[0]][hash_key] = v def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -97,7 +107,7 @@ def get(self, selector: Sequence[str], /) -> Segment | None: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value @@ -118,7 +128,7 @@ def get_any(self, selector: Sequence[str], /) -> Any | None: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None def remove(self, selector: Sequence[str], /): @@ -134,7 +144,19 @@ def remove(self, selector: Sequence[str], /): if not selector: return if len(selector) == 1: - self._variable_dictionary[selector[0]] = {} + self.variable_dictionary[selector[0]] = {} return hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[selector[0]].pop(hash_key, None) + + def remove_node(self, node_id: str, /): + """ + Remove all variables associated with a given node id. + + Args: + node_id (str): The node id to remove. + + Returns: + None + """ + self.variable_dictionary.pop(node_id, None) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 9b35b8df8aa8e9..4bf4e454bbfde3 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -66,8 +66,7 @@ def __init__(self, workflow: Workflow, self.variable_pool = variable_pool self.total_tokens = 0 - self.workflow_nodes_and_results = [] - self.current_iteration_state = None self.workflow_node_steps = 1 - self.workflow_node_runs = [] \ No newline at end of file + self.workflow_node_runs = [] + self.current_iteration_state = None diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index fe79fadf66b876..07cbcd981e6c79 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,10 +1,8 @@ -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): - self.node_id = node_id - self.node_type = node_type - self.node_title = node_title + def __init__(self, node_instance: BaseNode, error: str): + self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_title} run failed: {error}") + super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/graph_engine/condition_handlers/__init__.py b/api/core/workflow/graph_engine/condition_handlers/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py new file mode 100644 index 00000000000000..4099def4e2ab57 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod + +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class RunConditionHandler(ABC): + def __init__(self, + init_params: GraphInitParams, + graph: Graph, + condition: RunCondition): + self.init_params = init_params + self.graph = graph + self.condition = condition + + @abstractmethod + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState + ) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py new file mode 100644 index 00000000000000..705eb908b1607a --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -0,0 +1,28 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class BranchIdentifyRunConditionHandler(RunConditionHandler): + + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.branch_identify: + raise Exception("Branch identify is required") + + run_result = previous_route_node_state.node_run_result + if not run_result: + return False + + if not run_result.edge_source_handle: + return False + + return self.condition.branch_identify == run_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py new file mode 100644 index 00000000000000..1edaf92da7afab --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -0,0 +1,32 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.utils.condition.processor import ConditionProcessor + + +class ConditionRunConditionHandlerHandler(RunConditionHandler): + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState + ) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.conditions: + return True + + # process condition + condition_processor = ConditionProcessor() + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, + conditions=self.condition.conditions + ) + + # Apply the logical operator for the current case + compare_result = all(group_result) + + return compare_result diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py new file mode 100644 index 00000000000000..2eb2e58bfcabdf --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -0,0 +1,35 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class ConditionManager: + @staticmethod + def get_condition_handler( + init_params: GraphInitParams, + graph: Graph, + run_condition: RunCondition + ) -> RunConditionHandler: + """ + Get condition handler + + :param init_params: init params + :param graph: graph + :param run_condition: run condition + :return: condition handler + """ + if run_condition.type == "branch_identify": + return BranchIdentifyRunConditionHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) + else: + return ConditionRunConditionHandlerHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py new file mode 100644 index 00000000000000..06dc4cb8f4393c --- /dev/null +++ b/api/core/workflow/graph_engine/entities/event.py @@ -0,0 +1,163 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class GraphEngineEvent(BaseModel): + pass + + +########################################### +# Graph Events +########################################### + + +class BaseGraphEvent(GraphEngineEvent): + pass + + +class GraphRunStartedEvent(BaseGraphEvent): + pass + + +class GraphRunSucceededEvent(BaseGraphEvent): + outputs: Optional[dict[str, Any]] = None + """outputs""" + + +class GraphRunFailedEvent(BaseGraphEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Node Events +########################################### + + +class BaseNodeEvent(GraphEngineEvent): + id: str = Field(..., description="node execution id") + node_id: str = Field(..., description="node id") + node_type: NodeType = Field(..., description="node type") + node_data: BaseNodeData = Field(..., description="node data") + route_node_state: RouteNodeState = Field(..., description="route node state") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class NodeRunStartedEvent(BaseNodeEvent): + predecessor_node_id: Optional[str] = None + """predecessor node id""" + + +class NodeRunStreamChunkEvent(BaseNodeEvent): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + + +class NodeRunRetrieverResourceEvent(BaseNodeEvent): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +class NodeRunSucceededEvent(BaseNodeEvent): + pass + + +class NodeRunFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + +########################################### +# Parallel Branch Events +########################################### + + +class BaseParallelBranchEvent(GraphEngineEvent): + parallel_id: str = Field(..., description="parallel id") + """parallel id""" + parallel_start_node_id: str = Field(..., description="parallel start node id") + """parallel start node id""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Iteration Events +########################################### + + +class BaseIterationEvent(GraphEngineEvent): + iteration_id: str = Field(..., description="iteration node execution id") + iteration_node_id: str = Field(..., description="iteration node id") + iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") + iteration_node_data: BaseNodeData = Field(..., description="node data") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + + +class IterationRunStartedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + predecessor_node_id: Optional[str] = None + + +class IterationRunNextEvent(BaseIterationEvent): + index: int = Field(..., description="index") + pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + + +class IterationRunSucceededEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + +class IterationRunFailedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + error: str = Field(..., description="failed reason") + + +InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py new file mode 100644 index 00000000000000..49007b870d7f8f --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -0,0 +1,692 @@ +import uuid +from collections.abc import Mapping +from typing import Any, Optional, cast + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter +from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute +from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter +from core.workflow.nodes.end.entities import EndStreamParam + + +class GraphEdge(BaseModel): + source_node_id: str = Field(..., description="source node id") + target_node_id: str = Field(..., description="target node id") + run_condition: Optional[RunCondition] = None + """run condition""" + + +class GraphParallel(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") + start_from_node_id: str = Field(..., description="start from node id") + parent_parallel_id: Optional[str] = None + """parent parallel id""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id""" + end_to_node_id: Optional[str] = None + """end to node id""" + + +class Graph(BaseModel): + root_node_id: str = Field(..., description="root node id of the graph") + node_ids: list[str] = Field(default_factory=list, description="graph node ids") + node_id_config_mapping: dict[str, dict] = Field( + default_factory=list, + description="node configs mapping (node id: node config)" + ) + edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, + description="graph edge mapping (source node id: edges)" + ) + reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, + description="reverse graph edge mapping (target node id: edges)" + ) + parallel_mapping: dict[str, GraphParallel] = Field( + default_factory=dict, + description="graph parallel mapping (parallel id: parallel)" + ) + node_parallel_mapping: dict[str, str] = Field( + default_factory=dict, + description="graph node parallel mapping (node id: parallel id)" + ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( + ..., + description="answer stream generate routes" + ) + end_stream_param: EndStreamParam = Field( + ..., + description="end stream param" + ) + + @classmethod + def init(cls, + graph_config: Mapping[str, Any], + root_node_id: Optional[str] = None) -> "Graph": + """ + Init graph + + :param graph_config: graph config + :param root_node_id: root node id + :return: graph + """ + # edge configs + edge_configs = graph_config.get('edges') + if edge_configs is None: + edge_configs = [] + + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + edge_mapping: dict[str, list[GraphEdge]] = {} + reverse_edge_mapping: dict[str, list[GraphEdge]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get('source') + if not source_node_id: + continue + + if source_node_id not in edge_mapping: + edge_mapping[source_node_id] = [] + + target_node_id = edge_config.get('target') + if not target_node_id: + continue + + if target_node_id not in reverse_edge_mapping: + reverse_edge_mapping[target_node_id] = [] + + # is target node id in source node id edge mapping + if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]): + continue + + target_edge_ids.add(target_node_id) + + # parse run condition + run_condition = None + if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': + run_condition = RunCondition( + type='branch_identify', + branch_identify=edge_config.get('sourceHandle') + ) + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=target_node_id, + run_condition=run_condition + ) + + edge_mapping[source_node_id].append(graph_edge) + reverse_edge_mapping[target_node_id].append(graph_edge) + + # node configs + node_configs = graph_config.get('nodes') + if not node_configs: + raise ValueError("Graph must have at least one node") + + node_configs = cast(list, node_configs) + + # fetch nodes that have no predecessor node + root_node_configs = [] + all_node_id_config_mapping: dict[str, dict] = {} + for node_config in node_configs: + node_id = node_config.get('id') + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + all_node_id_config_mapping[node_id] = node_config + + root_node_ids = [node_config.get('id') for node_config in root_node_configs] + + # fetch root node + if not root_node_id: + # if no root node id, use the START type node as root node + root_node_id = next((node_config.get("id") for node_config in root_node_configs + if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + + if not root_node_id or root_node_id not in root_node_ids: + raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # Check whether it is connected to the previous node + cls._check_connected_to_previous_node( + route=[root_node_id], + edge_mapping=edge_mapping + ) + + # fetch all node ids from root node + node_ids = [root_node_id] + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=root_node_id + ) + + node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} + + # init parallel mapping + parallel_mapping: dict[str, GraphParallel] = {} + node_parallel_mapping: dict[str, str] = {} + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=root_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping + ) + + # Check if it exceeds N layers of parallel + for parallel in parallel_mapping.values(): + if parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=3, + parent_parallel_id=parallel.parent_parallel_id + ) + + # init answer stream generate routes + answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping + ) + + # init end stream param + end_stream_param = EndStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + node_parallel_mapping=node_parallel_mapping + ) + + # init graph + graph = cls( + root_node_id=root_node_id, + node_ids=node_ids, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + answer_stream_generate_routes=answer_stream_generate_routes, + end_stream_param=end_stream_param + ) + + return graph + + def add_extra_edge(self, source_node_id: str, + target_node_id: str, + run_condition: Optional[RunCondition] = None) -> None: + """ + Add extra edge to the graph + + :param source_node_id: source node id + :param target_node_id: target node id + :param run_condition: run condition + """ + if source_node_id not in self.node_ids or target_node_id not in self.node_ids: + return + + if source_node_id not in self.edge_mapping: + self.edge_mapping[source_node_id] = [] + + if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: + return + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=target_node_id, + run_condition=run_condition + ) + + self.edge_mapping[source_node_id].append(graph_edge) + + def get_leaf_node_ids(self) -> list[str]: + """ + Get leaf node ids of the graph + + :return: leaf node ids + """ + leaf_node_ids = [] + for node_id in self.node_ids: + if node_id not in self.edge_mapping: + leaf_node_ids.append(node_id) + elif (len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + leaf_node_ids.append(node_id) + + return leaf_node_ids + + @classmethod + def _recursively_add_node_ids(cls, + node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + node_id: str) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=graph_edge.target_node_id + ) + + @classmethod + def _check_connected_to_previous_node( + cls, + route: list[str], + edge_mapping: dict[str, list[GraphEdge]] + ) -> None: + """ + Check whether it is connected to the previous node + """ + last_node_id = route[-1] + + for graph_edge in edge_mapping.get(last_node_id, []): + if not graph_edge.target_node_id: + continue + + if graph_edge.target_node_id in route: + raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + + new_route = route[:] + new_route.append(graph_edge.target_node_id) + cls._check_connected_to_previous_node( + route=new_route, + edge_mapping=edge_mapping, + ) + + @classmethod + def _recursively_add_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str], + parent_parallel: Optional[GraphParallel] = None + ) -> None: + """ + Recursively add parallel ids + + :param edge_mapping: edge mapping + :param start_node_id: start from node id + :param parallel_mapping: parallel mapping + :param node_parallel_mapping: node parallel mapping + :param parent_parallel: parent parallel + """ + target_node_edges = edge_mapping.get(start_node_id, []) + parallel = None + if len(target_node_edges) > 1: + # fetch all node ids in current parallels + parallel_branch_node_ids = [] + condition_edge_mappings = {} + for graph_edge in target_node_edges: + if graph_edge.run_condition is None: + parallel_branch_node_ids.append(graph_edge.target_node_id) + else: + condition_hash = graph_edge.run_condition.hash + if not condition_hash in condition_edge_mappings: + condition_edge_mappings[condition_hash] = [] + + condition_edge_mappings[condition_hash].append(graph_edge) + + for _, graph_edges in condition_edge_mappings.items(): + if len(graph_edges) > 1: + for graph_edge in graph_edges: + parallel_branch_node_ids.append(graph_edge.target_node_id) + + # any target node id in node_parallel_mapping + if parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None + ) + parallel_mapping[parallel.id] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=parallel_branch_node_ids + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) + or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + for graph_edge in target_node_edges: + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if ( + parent_parallel_parent_parallel + and ( + not parent_parallel_parent_parallel.end_to_node_id + or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) + ) + ): + current_parallel = parent_parallel_parent_parallel + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel + ) + + @classmethod + def _check_exceed_parallel_limit( + cls, + parallel_mapping: dict[str, GraphParallel], + level_limit: int, + parent_parallel_id: str, + current_level: int = 1 + ) -> None: + """ + Check if it exceeds N layers of parallel + """ + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + return + + current_level += 1 + if current_level > level_limit: + raise ValueError(f"Exceeds {level_limit} layers of parallel") + + if parent_parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=level_limit, + parent_parallel_id=parent_parallel.parent_parallel_id, + current_level=current_level + ) + + @classmethod + def _recursively_add_parallel_node_ids(cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str) -> None: + """ + Recursively add node ids + + :param branch_node_ids: in branch node ids + :param edge_mapping: edge mapping + :param merge_node_id: merge node id + :param start_node_id: start node id + """ + for graph_edge in edge_mapping.get(start_node_id, []): + if (graph_edge.target_node_id != merge_node_id + and graph_edge.target_node_id not in branch_node_ids): + branch_node_ids.append(graph_edge.target_node_id) + cls._recursively_add_parallel_node_ids( + branch_node_ids=branch_node_ids, + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=graph_edge.target_node_id + ) + + @classmethod + def _fetch_all_node_ids_in_parallels(cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: + """ + Fetch all node ids in parallels + """ + routes_node_ids: dict[str, list[str]] = {} + for parallel_branch_node_id in parallel_branch_node_ids: + routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] + + # fetch routes node ids + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=parallel_branch_node_id, + routes_node_ids=routes_node_ids[parallel_branch_node_id] + ) + + # fetch leaf node ids from routes node ids + leaf_node_ids: dict[str, list[str]] = {} + merge_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: + if branch_node_id not in leaf_node_ids: + leaf_node_ids[branch_node_id] = [] + + leaf_node_ids[branch_node_id].append(node_id) + + for branch_node_id2, inner_route2 in routes_node_ids.items(): + if ( + branch_node_id != branch_node_id2 + and node_id in inner_route2 + and len(reverse_edge_mapping.get(node_id, [])) > 1 + and cls._is_node_in_routes( + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=node_id, + routes_node_ids=routes_node_ids + ) + ): + if node_id not in merge_branch_node_ids: + merge_branch_node_ids[node_id] = [] + + if branch_node_id2 not in merge_branch_node_ids[node_id]: + merge_branch_node_ids[node_id].append(branch_node_id2) + + # sorted merge_branch_node_ids by branch_node_ids length desc + merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + + duplicate_end_node_ids = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): + if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): + if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids + + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): + # check which node is after + if cls._is_node2_after_node1( + node1_id=node_id, + node2_id=node_id2, + edge_mapping=edge_mapping + ): + if node_id in merge_branch_node_ids: + del merge_branch_node_ids[node_id2] + elif cls._is_node2_after_node1( + node1_id=node_id2, + node2_id=node_id, + edge_mapping=edge_mapping + ): + if node_id2 in merge_branch_node_ids: + del merge_branch_node_ids[node_id] + + branches_merge_node_ids: dict[str, str] = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + if len(branch_node_ids) <= 1: + continue + + for branch_node_id in branch_node_ids: + if branch_node_id in branches_merge_node_ids: + continue + + branches_merge_node_ids[branch_node_id] = node_id + + in_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + in_branch_node_ids[branch_node_id] = [] + if branch_node_id not in branches_merge_node_ids: + # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].append(branch_node_id) + in_branch_node_ids[branch_node_id].extend(node_ids) + else: + merge_node_id = branches_merge_node_ids[branch_node_id] + if merge_node_id != branch_node_id: + in_branch_node_ids[branch_node_id].append(branch_node_id) + + # fetch all node ids from branch_node_id and merge_node_id + cls._recursively_add_parallel_node_ids( + branch_node_ids=in_branch_node_ids[branch_node_id], + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=branch_node_id + ) + + return in_branch_node_ids + + @classmethod + def _recursively_fetch_routes(cls, + edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: list[str]) -> None: + """ + Recursively fetch route + """ + if start_node_id not in edge_mapping: + return + + for graph_edge in edge_mapping[start_node_id]: + # find next node ids + if graph_edge.target_node_id not in routes_node_ids: + routes_node_ids.append(graph_edge.target_node_id) + + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=graph_edge.target_node_id, + routes_node_ids=routes_node_ids + ) + + @classmethod + def _is_node_in_routes(cls, + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: dict[str, list[str]]) -> bool: + """ + Recursively check if the node is in the routes + """ + if start_node_id not in reverse_edge_mapping: + return False + + all_routes_node_ids = set() + parallel_start_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + all_routes_node_ids.add(node_id) + + if branch_node_id in reverse_edge_mapping: + for graph_edge in reverse_edge_mapping[branch_node_id]: + if graph_edge.source_node_id not in parallel_start_node_ids: + parallel_start_node_ids[graph_edge.source_node_id] = [] + + parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) + + parallel_start_node_id = None + for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + if set(branch_node_ids) == set(routes_node_ids.keys()): + parallel_start_node_id = p_start_node_id + return True + + if not parallel_start_node_id: + raise Exception("Parallel start node id not found") + + for graph_edge in reverse_edge_mapping[start_node_id]: + if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: + return False + + return True + + @classmethod + def _is_node2_after_node1( + cls, + node1_id: str, + node2_id: str, + edge_mapping: dict[str, list[GraphEdge]] + ) -> bool: + """ + is node2 after node1 + """ + if node1_id not in edge_mapping: + return False + + for graph_edge in edge_mapping[node1_id]: + if graph_edge.target_node_id == node2_id: + return True + + if cls._is_node2_after_node1( + node1_id=graph_edge.target_node_id, + node2_id=node2_id, + edge_mapping=edge_mapping + ): + return True + + return False \ No newline at end of file diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py new file mode 100644 index 00000000000000..1a403f3e49f265 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from models.workflow import WorkflowType + + +class GraphInitParams(BaseModel): + # init params + tenant_id: str = Field(..., description="tenant / workspace id") + app_id: str = Field(..., description="app id") + workflow_type: WorkflowType = Field(..., description="workflow type") + workflow_id: str = Field(..., description="workflow id") + graph_config: Mapping[str, Any] = Field(..., description="graph config") + user_id: str = Field(..., description="user id") + user_from: UserFrom = Field(..., description="user from, account or end-user") + invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py new file mode 100644 index 00000000000000..c7d484ddf55f7c --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -0,0 +1,27 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState + + +class GraphRuntimeState(BaseModel): + variable_pool: VariablePool = Field(..., description="variable pool") + """variable pool""" + + start_at: float = Field(..., description="start time") + """start time""" + total_tokens: int = 0 + """total tokens""" + llm_usage: LLMUsage = LLMUsage.empty_usage() + """llm usage info""" + outputs: dict[str, Any] = {} + """outputs""" + + node_run_steps: int = 0 + """node run steps""" + + node_run_state: RuntimeRouteState = RuntimeRouteState() + """node run state""" diff --git a/api/core/workflow/graph_engine/entities/next_graph_node.py b/api/core/workflow/graph_engine/entities/next_graph_node.py new file mode 100644 index 00000000000000..6aa4341ddfe171 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/next_graph_node.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.graph_engine.entities.graph import GraphParallel + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py new file mode 100644 index 00000000000000..0362343568dd56 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -0,0 +1,21 @@ +import hashlib +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify like: sourceHandle, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" + + @property + def hash(self) -> str: + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() \ No newline at end of file diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py new file mode 100644 index 00000000000000..b5d6e4c09d70e2 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -0,0 +1,111 @@ +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus + + +class RouteNodeState(BaseModel): + class Status(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """node state id""" + + node_id: str + """node id""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.RUNNING + """node status""" + + start_at: datetime + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + index: int = 1 + + def set_finished(self, run_result: NodeRunResult) -> None: + """ + Node finished + + :param run_result: run result + """ + if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + raise Exception(f"Route state {self.id} already finished") + + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + self.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + self.status = RouteNodeState.Status.FAILED + self.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + self.node_run_result = run_result + self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + +class RuntimeRouteState(BaseModel): + routes: dict[str, list[str]] = Field( + default_factory=dict, + description="graph state routes (source_node_state_id: target_node_state_id)" + ) + + node_state_mapping: dict[str, RouteNodeState] = Field( + default_factory=dict, + description="node state mapping (route_node_state_id: route_node_state)" + ) + + def create_node_state(self, node_id: str) -> RouteNodeState: + """ + Create node state + + :param node_id: node id + """ + state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + self.node_state_mapping[state.id] = state + return state + + def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + """ + Add route to the graph state + + :param source_node_state_id: source node state id + :param target_node_state_id: target node state id + """ + if source_node_state_id not in self.routes: + self.routes[source_node_state_id] = [] + + self.routes[source_node_state_id].append(target_node_state_id) + + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ + -> list[RouteNodeState]: + """ + Get routes with node state by source node id + + :param source_node_state_id: source node state id + :return: routes with node state + """ + return [self.node_state_mapping[target_state_id] + for target_state_id in self.routes.get(source_node_state_id, [])] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py new file mode 100644 index 00000000000000..65d9ab8446ef1c --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -0,0 +1,716 @@ +import logging +import queue +import time +import uuid +from collections.abc import Generator, Mapping +from concurrent.futures import ThreadPoolExecutor, wait +from typing import Any, Optional + +from flask import Flask, current_app + +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeType, + UserFrom, +) +from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.event import ( + BaseIterationEvent, + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph, GraphEdge +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.node_mapping import node_classes +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + +logger = logging.getLogger(__name__) + + +class GraphEngineThreadPool(ThreadPoolExecutor): + def __init__(self, max_workers=None, thread_name_prefix='', + initializer=None, initargs=(), max_submit_count=100) -> None: + super().__init__(max_workers, thread_name_prefix, initializer, initargs) + self.max_submit_count = max_submit_count + self.submit_count = 0 + + def submit(self, fn, *args, **kwargs): + self.submit_count += 1 + self.check_is_full() + + return super().submit(fn, *args, **kwargs) + + def check_is_full(self) -> None: + print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") + if self.submit_count > self.max_submit_count: + raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") + + +class GraphEngine: + workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} + + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None + ) -> None: + thread_pool_max_submit_count = 100 + thread_pool_max_workers = 10 + + ## init thread pool + if thread_pool_id: + if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") + + self.thread_pool_id = thread_pool_id + self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] + self.is_main_thread_pool = False + else: + self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool_id = str(uuid.uuid4()) + self.is_main_thread_pool = True + GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool + + self.graph = graph + self.init_params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth + ) + + self.graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter() + ) + + self.max_execution_steps = max_execution_steps + self.max_execution_time = max_execution_time + + def run(self) -> Generator[GraphEngineEvent, None, None]: + # trigger graph run start event + yield GraphRunStartedEvent() + + try: + stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] + if self.init_params.workflow_type == WorkflowType.CHAT: + stream_processor_cls = AnswerStreamProcessor + else: + stream_processor_cls = EndStreamProcessor + + stream_processor = stream_processor_cls( + graph=self.graph, + variable_pool=self.graph_runtime_state.variable_pool + ) + + # run graph + generator = stream_processor.process( + self._run(start_node_id=self.graph.root_node_id) + ) + + for item in generator: + try: + yield item + if isinstance(item, NodeRunFailedEvent): + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') + return + elif isinstance(item, NodeRunSucceededEvent): + if item.node_type == NodeType.END: + self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {}) + elif item.node_type == NodeType.ANSWER: + if "answer" not in self.graph_runtime_state.outputs: + self.graph_runtime_state.outputs["answer"] = "" + + self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "") + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() + except Exception as e: + logger.exception(f"Graph run failed: {str(e)}") + yield GraphRunFailedEvent(error=str(e)) + return + + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + except GraphRunFailedError as e: + yield GraphRunFailedEvent(error=e.error) + return + except Exception as e: + logger.exception("Unknown Error when graph running") + yield GraphRunFailedEvent(error=str(e)) + raise e + finally: + if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] + + def _run( + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None + ) -> Generator[GraphEngineEvent, None, None]: + parallel_start_node_id = None + if in_parallel_id: + parallel_start_node_id = start_node_id + + next_node_id = start_node_id + previous_route_node_state: Optional[RouteNodeState] = None + while True: + # max steps reached + if self.graph_runtime_state.node_run_steps > self.max_execution_steps: + raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + + # or max execution time reached + if self._is_timed_out( + start_at=self.graph_runtime_state.start_at, + max_execution_time=self.max_execution_time + ): + raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + + # init route node state + route_node_state = self.graph_runtime_state.node_run_state.create_node_state( + node_id=next_node_id + ) + + # get node config + node_id = route_node_state.node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f'Node {node_id} config not found.') + + # convert to specific node + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + if not node_cls: + raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + + previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None + + # init workflow run state + node_instance = node_cls( # type: ignore + id=route_node_state.id, + config=node_config, + graph_init_params=self.init_params, + graph=self.graph, + graph_runtime_state=self.graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=self.thread_pool_id + ) + + try: + # run node + generator = self._run_node( + node_instance=node_instance, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + for item in generator: + if isinstance(item, NodeRunStartedEvent): + self.graph_runtime_state.node_run_steps += 1 + item.route_node_state.index = self.graph_runtime_state.node_run_steps + + yield item + + self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state + + # append route + if previous_route_node_state: + self.graph_runtime_state.node_run_state.add_route( + source_node_state_id=previous_route_node_state.id, + target_node_state_id=route_node_state.id + ) + except Exception as e: + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = str(e) + yield NodeRunFailedEvent( + error=str(e), + id=node_instance.id, + node_id=next_node_id, + node_type=node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + raise e + + # It may not be necessary, but it is necessary. :) + if (self.graph.node_id_config_mapping[next_node_id] + .get("data", {}).get("type", "").lower() == NodeType.END.value): + break + + previous_route_node_state = route_node_state + + # get next node ids + edge_mappings = self.graph.edge_mapping.get(next_node_id) + if not edge_mappings: + break + + if len(edge_mappings) == 1: + edge = edge_mappings[0] + + if edge.run_condition: + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state + ) + + if not result: + break + + next_node_id = edge.target_node_id + else: + final_node_id = None + + if any(edge.run_condition for edge in edge_mappings): + # if nodes has run conditions, get node id which branch to take based on the run condition results + condition_edge_mappings = {} + for edge in edge_mappings: + if edge.run_condition: + run_condition_hash = edge.run_condition.hash + if run_condition_hash not in condition_edge_mappings: + condition_edge_mappings[run_condition_hash] = [] + + condition_edge_mappings[run_condition_hash].append(edge) + + for _, sub_edge_mappings in condition_edge_mappings.items(): + if len(sub_edge_mappings) == 0: + continue + + edge = sub_edge_mappings[0] + + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state, + ) + + if not result: + continue + + if len(sub_edge_mappings) == 1: + final_node_id = edge.target_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=sub_edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + break + + if not final_node_id: + break + + next_node_id = final_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + if not final_node_id: + break + + next_node_id = final_node_id + + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: + break + + def _run_parallel_branches( + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent | str, None, None]: + # if nodes has no run conditions, parallel run all nodes + parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) + if not parallel_id: + node_id = edge_mappings[0].target_node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') + + node_title = node_config.get('data', {}).get('title') + raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') + + parallel = self.graph.parallel_mapping.get(parallel_id) + if not parallel: + raise GraphRunFailedError(f'Parallel {parallel_id} not found.') + + # run parallel nodes, run in new thread and use queue to get results + q: queue.Queue = queue.Queue() + + # Create a list to store the threads + futures = [] + + # new thread + for edge in edge_mappings: + if ( + edge.target_node_id not in self.graph.node_parallel_mapping + or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + ): + continue + + futures.append( + self.thread_pool.submit(self._run_parallel_node, **{ + 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] + 'q': q, + 'parallel_id': parallel_id, + 'parallel_start_node_id': edge.target_node_id, + 'parent_parallel_id': in_parallel_id, + 'parent_parallel_start_node_id': parallel_start_node_id, + }) + ) + + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + + yield event + if event.parallel_id == parallel_id: + if isinstance(event, ParallelBranchRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + + continue + elif isinstance(event, ParallelBranchRunFailedEvent): + raise GraphRunFailedError(event.error) + except queue.Empty: + continue + + # wait all threads + wait(futures) + + # get final node id + final_node_id = parallel.end_to_node_id + if final_node_id: + yield final_node_id + + def _run_parallel_node( + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> None: + """ + Run parallel nodes + """ + with flask_app.app_context(): + try: + q.put(ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + )) + + # run node + generator = self._run( + start_node_id=parallel_start_node_id, + in_parallel_id=parallel_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + for item in generator: + q.put(item) + + # trigger graph run success event + q.put(ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + )) + except GraphRunFailedError as e: + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error + )) + except Exception as e: + logger.exception("Unknown Error when generating in parallel") + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e) + )) + finally: + db.session.remove() + + def _run_node( + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + """ + Run node + """ + # trigger node run start event + yield NodeRunStartedEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + predecessor_node_id=node_instance.previous_node_id, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + db.session.close() + + try: + # run node + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id + + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or 'Unknown error.', + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) + + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + + # add parallel info to run result metadata + if parallel_id and parallel_start_node_id: + if not run_result.metadata: + run_result.metadata = {} + + run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + except GenerateTaskStoppedException: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." + yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") + raise e + finally: + db.session.close() + + def _append_variables_recursively(self, + node_id: str, + variable_key_list: list[str], + variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + self.graph_runtime_state.variable_pool.add( + [node_id] + variable_key_list, + variable_value + ) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + node_id=node_id, + variable_key_list=new_key_list, + variable_value=value + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + +class GraphRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 5bae27092f920d..8cf01727ec9296 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,9 +1,8 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, GenerateRouteChunk, @@ -19,24 +18,26 @@ class AnswerNode(BaseNode): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(AnswerNodeData, node_data) # generate routes - generate_routes = self.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) answer = '' for part in generate_routes: - if part.type == "var": + if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = variable_pool.get(value_selector) + value = self.graph_runtime_state.variable_pool.get( + value_selector + ) + if value: answer += value.markdown else: @@ -51,70 +52,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: - """ - Extract generate route selectors - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(AnswerNodeData, node_data) - - return cls.extract_generate_route_from_node_data(node_data) - - @classmethod - def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: - """ - Extract generate route from node data - :param node_data: node data object - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors - } - - variable_keys = list(value_selector_mapping.keys()) - - # format answer template - template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - generate_routes = [] - for part in template.split('Ω'): - if part: - if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') - value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) - else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) - - return generate_routes - - @classmethod - def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -126,6 +73,6 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py new file mode 100644 index 00000000000000..6cb80091c91fbd --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -0,0 +1,169 @@ + +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + AnswerStreamGenerateRoute, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class AnswerStreamGeneratorRouter: + + @classmethod + def init(cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selectors of answer nodes + answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} + for answer_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + continue + + # get generate route for stream output + generate_route = cls._extract_generate_route_selectors(node_config) + answer_generate_route[answer_node_id] = generate_route + + # fetch answer dependencies + answer_node_ids = list(answer_generate_route.keys()) + answer_dependencies = cls._fetch_answers_dependencies( + answer_node_ids=answer_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping + ) + + return AnswerStreamGenerateRoute( + answer_generate_route=answer_generate_route, + answer_dependencies=answer_dependencies + ) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in variable_selectors + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + generate_routes: list[GenerateRouteChunk] = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) + + return generate_routes + + @classmethod + def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = AnswerNodeData(**config.get("data", {})) + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys + + @classmethod + def _fetch_answers_dependencies(cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict] + ) -> dict[str, list[str]]: + """ + Fetch answer dependencies + :param answer_node_ids: answer node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + answer_dependencies: dict[str, list[str]] = {} + for answer_node_id in answer_node_ids: + if answer_dependencies.get(answer_node_id) is None: + answer_dependencies[answer_node_id] = [] + + cls._recursive_fetch_answer_dependencies( + current_node_id=answer_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies + ) + + return answer_dependencies + + @classmethod + def _recursive_fetch_answer_dependencies(cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]] + ) -> None: + """ + Recursive fetch answer dependencies + :param current_node_id: current node id + :param answer_node_id: answer node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param answer_dependencies: answer dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + if source_node_type in ( + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + ): + answer_dependencies[answer_node_id].append(source_node_id) + else: + cls._recursive_fetch_answer_dependencies( + current_node_id=source_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies + ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py new file mode 100644 index 00000000000000..c2a5dd5163819d --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -0,0 +1,221 @@ +import logging +from collections.abc import Generator +from typing import Optional, cast + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk + +logger = logging.getLogger(__name__) + + +class AnswerStreamProcessor(StreamProcessor): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.generate_routes = graph.answer_stream_generate_routes + self.route_position = {} + for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + self.route_position[answer_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) + self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] = stream_out_answer_node_ids + + for _ in stream_out_answer_node_ids: + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[answer_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + self.route_position[answer_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished(self, + event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for answer_node_id, position in self.route_position.items(): + # all depends on answer node id not in rest node ids + if (event.route_node_state.node_id != answer_node_id + and (answer_node_id not in self.rest_node_ids + or not all(dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): + continue + + route_position = self.route_position[answer_node_id] + route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:] + + for route_chunk in route_chunks: + if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT: + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=route_chunk.text, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + else: + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + if not value_selector: + break + + value = self.variable_pool.get( + value_selector + ) + + if value is None: + break + + text = value.markdown + + if text: + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[answer_node_id] += 1 + + def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_answer_node_ids = [] + for answer_node_id, route_position in self.route_position.items(): + if answer_node_id not in self.rest_node_ids: + continue + + # all depends on answer node id not in rest node ids + if all(dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): + if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + continue + + route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] + + if route_chunk.type != GenerateRouteChunk.ChunkType.VAR: + continue + + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_answer_node_ids.append(answer_node_id) + + return stream_out_answer_node_ids + + @classmethod + def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]: + """ + Fetch files from variable value + :param value: variable value + :return: + """ + if not value: + return [] + + files = [] + if isinstance(value, list): + for item in value: + file_var = cls._get_file_var_from_value(item) + if file_var: + files.append(file_var) + elif isinstance(value, dict): + file_var = cls._get_file_var_from_value(value) + if file_var: + files.append(file_var) + + return files + + @classmethod + def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict): + if '__variant' in value and value['__variant'] == FileVar.__name__: + return value + elif isinstance(value, FileVar): + return value.to_dict() + + return None diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py new file mode 100644 index 00000000000000..cbabbca37d0558 --- /dev/null +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.graph import Graph + + +class StreamProcessor(ABC): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + self.graph = graph + self.variable_pool = variable_pool + self.rest_node_ids = graph.node_ids.copy() + + @abstractmethod + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + raise NotImplementedError + + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + finished_node_id = event.route_node_state.node_id + if finished_node_id not in self.rest_node_ids: + return + + # remove finished node id + self.rest_node_ids.remove(finished_node_id) + + run_result = event.route_node_state.node_run_result + if not run_result: + return + + if run_result.edge_source_handle: + reachable_node_ids = [] + unreachable_first_node_ids = [] + for edge in self.graph.edge_mapping[finished_node_id]: + if (edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify): + reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + continue + else: + unreachable_first_node_ids.append(edge.target_node_id) + + for node_id in unreachable_first_node_ids: + self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + node_ids = [] + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + + node_ids.append(edge.target_node_id) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + return node_ids + + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + """ + remove target node ids until merge + """ + if node_id not in self.rest_node_ids: + return + + self.rest_node_ids.remove(node_id) + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id in reachable_node_ids: + continue + + self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 9effbbbe671420..620c2c426bae9e 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,6 @@ +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ - answer: str + answer: str = Field(..., description="answer template string") class GenerateRouteChunk(BaseModel): """ Generate Route Chunk. """ - type: str + + class ChunkType(Enum): + VAR = "var" + TEXT = "text" + + type: ChunkType = Field(..., description="generate route chunk type") class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ - type: str = "var" - value_selector: list[str] + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR + """generate route chunk type""" + value_selector: list[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ - type: str = "text" - text: str + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT + """generate route chunk type""" + text: str = Field(..., description="text") + + +class AnswerNodeDoubleLink(BaseModel): + node_id: str = Field(..., description="node id") + source_node_ids: list[str] = Field(..., description="source node ids") + target_node_ids: list[str] = Field(..., description="target node ids") + + +class AnswerStreamGenerateRoute(BaseModel): + """ + AnswerStreamGenerateRoute entity + """ + answer_dependencies: dict[str, list[str]] = Field( + ..., + description="answer dependencies (answer node id -> dependent answer node ids)" + ) + answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( + ..., + description="answer generate route (answer node id -> generate route chunks)" + ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3d9cf52771e146..b9912314f136f3 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,142 +1,103 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from enum import Enum +from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from models import WorkflowNodeExecutionStatus - - -class UserFrom(Enum): - """ - User from - """ - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent, RunEvent class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - tenant_id: str - app_id: str - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - node_id: str - node_data: BaseNodeData - node_run_result: Optional[NodeRunResult] = None - - callbacks: Sequence[WorkflowCallback] - - is_answer_previous_node: bool = False - - def __init__(self, tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, + def __init__(self, + id: str, config: Mapping[str, Any], - callbacks: Sequence[WorkflowCallback] | None = None, - workflow_call_depth: int = 0) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - self.workflow_id = workflow_id - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - # TODO: May need to check if key exists. - self.node_id = config["id"] - if not self.node_id: + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None) -> None: + self.id = id + self.tenant_id = graph_init_params.tenant_id + self.app_id = graph_init_params.app_id + self.workflow_type = graph_init_params.workflow_type + self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config + self.user_id = graph_init_params.user_id + self.user_from = graph_init_params.user_from + self.invoke_from = graph_init_params.invoke_from + self.workflow_call_depth = graph_init_params.call_depth + self.graph = graph + self.graph_runtime_state = graph_runtime_state + self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id + + node_id = config.get("id") + if not node_id: raise ValueError("Node ID is required.") + self.node_id = node_id self.node_data = self._node_data_cls(**config.get("data", {})) - self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) \ + -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ raise NotImplementedError - def run(self, variable_pool: VariablePool) -> NodeRunResult: + def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node entry - :param variable_pool: variable pool :return: """ - try: - result = self._run( - variable_pool=variable_pool - ) - self.node_run_result = result - return result - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + result = self._run() - def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: - """ - Publish text chunk - :param text: chunk text - :param value_selector: value selector - :return: - """ - if self.callbacks: - for callback in self.callbacks: - callback.on_node_text_chunk( - node_id=self.node_id, - text=text, - metadata={ - "node_type": self.node_type, - "is_answer_previous_node": self.is_answer_previous_node, - "value_selector": value_selector - } - ) + if isinstance(result, NodeRunResult): + yield RunCompletedEvent( + run_result=result + ) + else: + yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict): + def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config :param config: node config :return: """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping(node_data) + return cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + node_id=node_id, + node_data=node_data + ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: BaseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -158,38 +119,3 @@ def node_type(self) -> NodeType: :return: """ return self._node_type - -class BaseIterationNode(BaseNode): - @abstractmethod - def _run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node - :param variable_pool: variable pool - :return: - """ - raise NotImplementedError - - def run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node entry - :param variable_pool: variable pool - :return: - """ - return self._run(variable_pool=variable_pool) - - def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - return self._get_next_iteration(variable_pool, state) - - @abstractmethod - def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - raise NotImplementedError diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 335991ae87c285..955afdfa1dd4a1 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage @@ -6,7 +7,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -33,13 +33,13 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: return code_provider.get_default_config() - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run code - :param variable_pool: variable pool :return: """ - node_data = cast(CodeNodeData, self.node_data) + node_data = self.node_data + node_data = cast(CodeNodeData, node_data) # Get code language code_language = node_data.code_language @@ -49,7 +49,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable] = value # Run code @@ -311,13 +311,19 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 440dfa2f2710f3..552914b308a9b7 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,10 +11,9 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData _node_type = NodeType.END - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data @@ -24,7 +22,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: outputs = {} for variable_selector in output_variables: - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value return NodeRunResult( @@ -34,52 +32,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]: - """ - Extract generate nodes - :param graph: graph - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(EndNodeData, node_data) - - return cls.extract_generate_nodes_from_node_data(graph, node_data) - - @classmethod - def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]: - """ - Extract generate nodes from node data - :param graph: graph - :param node_data: node data object - :return: - """ - nodes = graph.get('nodes', []) - node_mapping = {node.get('id'): node for node in nodes} - - variable_selectors = node_data.outputs - - generate_nodes = [] - for variable_selector in variable_selectors: - if not variable_selector.value_selector: - continue - - node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_mapping: - node = node_mapping[node_id] - node_type = node.get('data', {}).get('type') - if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': - generate_nodes.append(node_id) - - # remove duplicates - generate_nodes = list(set(generate_nodes)) - - return generate_nodes - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py new file mode 100644 index 00000000000000..8390f6d81b4d31 --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -0,0 +1,148 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam + + +class EndStreamGeneratorRouter: + + @classmethod + def init(cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str] + ) -> EndStreamParam: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selector of end nodes + end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} + for end_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get('data', {}).get('type') == NodeType.END.value: + continue + + # skip end node in parallel + if end_node_id in node_parallel_mapping: + continue + + # get generate route for stream output + stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) + end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors + + # fetch end dependencies + end_node_ids = list(end_stream_variable_selectors_mapping.keys()) + end_dependencies = cls._fetch_ends_dependencies( + end_node_ids=end_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping + ) + + return EndStreamParam( + end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, + end_dependencies=end_dependencies + ) + + @classmethod + def extract_stream_variable_selector_from_node_data(cls, + node_id_config_mapping: dict[str, dict], + node_data: EndNodeData) -> list[list[str]]: + """ + Extract stream variable selector from node data + :param node_id_config_mapping: node id config mapping + :param node_data: node data object + :return: + """ + variable_selectors = node_data.outputs + + value_selectors = [] + for variable_selector in variable_selectors: + if not variable_selector.value_selector: + continue + + node_id = variable_selector.value_selector[0] + if node_id != 'sys' and node_id in node_id_config_mapping: + node = node_id_config_mapping[node_id] + node_type = node.get('data', {}).get('type') + if ( + variable_selector.value_selector not in value_selectors + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == 'text' + ): + value_selectors.append(variable_selector.value_selector) + + return value_selectors + + @classmethod + def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ + -> list[list[str]]: + """ + Extract stream variable selector from node config + :param node_id_config_mapping: node id config mapping + :param config: node config + :return: + """ + node_data = EndNodeData(**config.get("data", {})) + return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) + + @classmethod + def _fetch_ends_dependencies(cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict] + ) -> dict[str, list[str]]: + """ + Fetch end dependencies + :param end_node_ids: end node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + end_dependencies: dict[str, list[str]] = {} + for end_node_id in end_node_ids: + if end_dependencies.get(end_node_id) is None: + end_dependencies[end_node_id] = [] + + cls._recursive_fetch_end_dependencies( + current_node_id=end_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies + ) + + return end_dependencies + + @classmethod + def _recursive_fetch_end_dependencies(cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]] + ) -> None: + """ + Recursive fetch end dependencies + :param current_node_id: current node id + :param end_node_id: end node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param end_dependencies: end dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + if source_node_type in ( + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + ): + end_dependencies[end_node_id].append(source_node_id) + else: + cls._recursive_fetch_end_dependencies( + current_node_id=source_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies + ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py new file mode 100644 index 00000000000000..4474c2a78a2740 --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -0,0 +1,191 @@ +import logging +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +class EndStreamProcessor(StreamProcessor): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.end_stream_param = graph.end_stream_param + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + self.has_outputed = False + self.outputed_node_ids = set() + + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) + self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] = stream_out_end_node_ids + + if stream_out_end_node_ids: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[end_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished(self, + event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for end_node_id, position in self.route_position.items(): + # all depends on end node id not in rest node ids + if (event.route_node_state.node_id != end_node_id + and (end_node_id not in self.rest_node_ids + or not all(dep_id not in self.rest_node_ids + for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + continue + + route_position = self.route_position[end_node_id] + + position = 0 + value_selectors = [] + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position >= route_position: + value_selectors.append(current_value_selectors) + + position += 1 + + for value_selector in value_selectors: + if not value_selector: + continue + + value = self.variable_pool.get( + value_selector + ) + + if value is None: + break + + text = value.markdown + + if text: + current_node_id = value_selector[0] + if self.has_outputed and current_node_id not in self.outputed_node_ids: + text = '\n' + text + + self.outputed_node_ids.add(current_node_id) + self.has_outputed = True + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[end_node_id] += 1 + + def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_end_node_ids = [] + for end_node_id, route_position in self.route_position.items(): + if end_node_id not in self.rest_node_ids: + continue + + # all depends on end node id not in rest node ids + if all(dep_id not in self.rest_node_ids + for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): + continue + + position = 0 + value_selector = None + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position == route_position: + value_selector = current_value_selectors + break + + position += 1 + + if not value_selector: + continue + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_end_node_ids.append(end_node_id) + + return stream_out_end_node_ids diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index ad4fc8f04fd43c..a0edf7b5796bba 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,3 +1,5 @@ +from pydantic import BaseModel, Field + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData): END Node Data. """ outputs: list[VariableSelector] + + +class EndStreamParam(BaseModel): + """ + EndStreamParam entity + """ + end_dependencies: dict[str, list[str]] = Field( + ..., + description="end dependencies (end node id -> dependent node ids)" + ) + end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( + ..., + description="end stream variable selector mapping (end node id -> stream variable selectors)" + ) diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event.py new file mode 100644 index 00000000000000..276c13a6d4f490 --- /dev/null +++ b/api/core/workflow/nodes/event.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 037a7a1848ccfa..3f68c8b1d03ce5 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,15 +1,14 @@ import logging +from collections.abc import Mapping, Sequence from mimetypes import guess_extension from os import path -from typing import cast +from typing import Any, cast from configs import dify_config from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( HttpRequestNodeData, @@ -48,17 +47,22 @@ def get_default_config(cls, filters: dict | None = None) -> dict: }, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text + node_data.authorization.config.api_key = parser.convert_template( + template=node_data.authorization.config.api_key, + variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None try: http_executor = HttpExecutor( - node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool + node_data=node_data, + timeout=self._get_request_timeout(node_data), + variable_pool=self.graph_runtime_state.variable_pool ) # invoke http executor @@ -102,13 +106,19 @@ def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeo return timeout @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = cast(HttpRequestNodeData, node_data) try: http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) @@ -116,7 +126,7 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 7eb69b80dfe117..338277ace1a150 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -3,20 +3,7 @@ from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match", - # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" - ] - value: Optional[str] = None +from core.workflow.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 2b253764b7f090..ca87eecd0d4c5f 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,13 +1,10 @@ -import re -from collections.abc import Sequence -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData -from core.workflow.utils.variable_template_parser import VariableTemplateParser +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus @@ -15,31 +12,35 @@ class IfElseNode(BaseNode): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs = { + node_inputs: dict[str, list] = { "conditions": [] } - process_datas = { + process_datas: dict[str, list] = { "condition_results": [] } input_conditions = [] final_result = False selected_case_id = None + condition_processor = ConditionProcessor() try: # Check if the new cases structure is used if node_data.cases: for case in node_data.cases: - input_conditions, group_result = self.process_conditions(variable_pool, case.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=case.conditions + ) + # Apply the logical operator for the current case final_result = all(group_result) if case.logical_operator == "and" else any(group_result) @@ -58,7 +59,10 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: else: # Fallback to old structure if cases are not defined - input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=node_data.conditions + ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) @@ -94,376 +98,17 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return data - def evaluate_condition( - self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator - - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - elif comparison_operator == "regex match": - return self._assert_regex_match(actual_value, expected_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") - - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - for condition in conditions: - actual_variable = variable_pool.get_any(condition.variable_selector) - - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - expected_value = variable_template_parser.extract_variable_selectors() - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - else: - expected_value = condition.value - else: - expected_value = None - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_variable, - "expected_value": expected_value, - "comparison_operator": comparison_operator - } - ) - - result = self.evaluate_condition(actual_variable, expected_value, comparison_operator) - group_result.append(result) - - return input_conditions, group_result - - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value not in actual_value: - return False - return True - - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True - return False - - def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if actual_value is None: - return False - return re.search(expected_value, actual_value) is not None - - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True - return False - - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True - return False - - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True - return False - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 177b47b9518e00..5fc5a827aedd08 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData class IterationNodeData(BaseIterationNodeData): @@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData): iterator_selector: list[str] # variable selector output_selector: list[str] # output selector + +class IterationStartNodeData(BaseNodeData): + """ + Iteration Start Node Data. + """ + pass + class IterationState(BaseIterationState): """ Iteration State. diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 54dfe8b7f4e40e..93eff16c339558 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,124 +1,371 @@ -from typing import cast +import logging +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime, timezone +from typing import Any, cast +from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode -from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.graph_engine.entities.event import ( + BaseGraphEvent, + BaseNodeEvent, + BaseParallelBranchEvent, + GraphRunFailedEvent, + InNodeEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) -class IterationNode(BaseIterationNode): + +class IterationNode(BaseNode): """ Iteration Node. """ _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run the node. """ self.node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(self.node_data.iterator_selector) - - if not isinstance(iterator, list): - raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") + iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ - 'iterator_selector': iterator - }, outputs=[], metadata=IterationState.MetaData( - iterator_length=len(iterator) if iterator is not None else 0 - )) + if not iterator_list_segment: + raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - self._set_current_iteration_variable(variable_pool, state) - return state + iterator_list_value = iterator_list_segment.to_object() - def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - # resolve current output - self._resolve_current_output(variable_pool, state) - # move to next iteration - self._next_iteration(variable_pool, state) - - node_data = cast(IterationNodeData, self.node_data) - if self._reached_iteration_limit(variable_pool, state): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, + if not isinstance(iterator_list_value, list): + raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + + inputs = { + "iterator_selector": iterator_list_value + } + + graph_config = self.graph_config + + if not self.node_data.start_node_id: + raise ValueError(f'field start_node_id in iteration {self.node_id} not found') + + root_node_id = self.node_data.start_node_id + + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=root_node_id + ) + + if not iteration_graph: + raise ValueError('iteration graph not found') + + leaf_node_ids = iteration_graph.get_leaf_node_ids() + iteration_leaf_node_ids = [] + for leaf_node_id in leaf_node_ids: + node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) + if not node_config: + continue + + leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") + if not leaf_node_iteration_id: + continue + + if leaf_node_iteration_id != self.node_id: + continue + + iteration_leaf_node_ids.append(leaf_node_id) + + # add condition of end nodes to root node + iteration_graph.add_extra_edge( + source_node_id=leaf_node_id, + target_node_id=root_node_id, + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=[self.node_id, "index"], + comparison_operator="<", + value=str(len(iterator_list_value)) + ) + ] + ) + ) + + variable_pool = self.graph_runtime_state.variable_pool + + # append iteration variable (item, index) to variable pool + variable_pool.add( + [self.node_id, 'index'], + 0 + ) + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[0] + ) + + # init graph engine + from core.workflow.graph_engine.graph_engine import GraphEngine + graph_engine = GraphEngine( + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_type=self.workflow_type, + workflow_id=self.workflow_id, + user_id=self.user_id, + user_from=self.user_from, + invoke_from=self.invoke_from, + call_depth=self.workflow_call_depth, + graph=iteration_graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + ) + + start_at = datetime.now(timezone.utc).replace(tzinfo=None) + + yield IterationRunStartedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + metadata={ + "iterator_length": len(iterator_list_value) + }, + predecessor_node_id=self.previous_node_id + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=0, + pre_iteration_output=None + ) + + outputs: list[Any] = [] + try: + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START: + continue + + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) + event.route_node_state.node_run_result.metadata = metadata + + yield event + + # handle iteration run result + if event.route_node_state.node_id in iteration_leaf_node_ids: + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, 'index']) + if current_index is None: + raise ValueError(f'iteration {self.node_id} current index not found') + + next_index = int(current_index.to_object()) + 1 + variable_pool.add( + [self.node_id, 'index'], + next_index + ) + + if next_index < len(iterator_list_value): + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[next_index] + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder( + current_iteration_output) if current_iteration_output else None + ) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={ + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens + }, + error=event.error, + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + break + else: + event = cast(InNodeEvent, event) + yield event + + yield IterationRunSucceededEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, outputs={ - 'output': jsonable_encoder(state.outputs) + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens } ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'output': jsonable_encoder(outputs) + } + ) + ) + except Exception as e: + # iteration run failed + logger.exception("Iteration run failed") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={ + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens + }, + error=str(e), + ) - return node_data.start_node_id + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + finally: + # remove iteration variable (item, index) from variable pool after iteration run completed + variable_pool.remove([self.node_id, 'index']) + variable_pool.remove([self.node_id, 'item']) - def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: """ - Set current iteration variable. - :variable_pool: variable pool + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: """ - node_data = cast(IterationNodeData, self.node_data) + variable_mapping = { + f'{node_id}.input_selector': node_data.iterator_selector, + } - variable_pool.add((self.node_id, 'index'), state.index) - # get the iterator value - iterator = variable_pool.get_any(node_data.iterator_selector) + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=node_data.start_node_id + ) - if iterator is None or not isinstance(iterator, list): - return + if not iteration_graph: + raise ValueError('iteration graph not found') - if state.index < len(iterator): - variable_pool.add((self.node_id, 'item'), iterator[state.index]) + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): + if sub_node_config.get('data', {}).get('iteration_id') != node_id: + continue - def _next_iteration(self, variable_pool: VariablePool, state: IterationState): - """ - Move to next iteration. - :param variable_pool: variable pool - """ - state.index += 1 - self._set_current_iteration_variable(variable_pool, state) + # variable selector to variable mapping + try: + # Get node class + from core.workflow.nodes.node_mapping import node_classes + node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + if not node_cls: + continue - def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): - """ - Check if iteration limit is reached. - :return: True if iteration limit is reached, False otherwise - """ - node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(node_data.iterator_selector) + node_cls = cast(BaseNode, node_cls) + + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + config=sub_node_config + ) + sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + except NotImplementedError: + sub_node_variable_mapping = {} - if iterator is None or not isinstance(iterator, list): - return True + # remove iteration variables + sub_node_variable_mapping = { + sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() + if value[0] != node_id + } - return state.index >= len(iterator) - - def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): - """ - Resolve current output. - :param variable_pool: variable pool - """ - output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_any(output_selector) - # clear the output for this iteration - variable_pool.remove([self.node_id] + output_selector[1:]) - state.current_output = output - if output is not None: - # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). - if isinstance(output, list): - state.outputs.extend(output) - else: - state.outputs.append(output) + variable_mapping.update(sub_node_variable_mapping) - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ - return { - 'input_selector': node_data.iterator_selector, - } \ No newline at end of file + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() + if value[0] not in iteration_graph.node_ids + } + + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py new file mode 100644 index 00000000000000..25044cf3eb99b2 --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationStartNode(BaseNode): + """ + Iteration Start Node. + """ + _node_data_cls = IterationStartNodeData + _node_type = NodeType.ITERATION_START + + def _run(self) -> NodeRunResult: + """ + Run the node. + """ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 6c052c0d6bb6a1..2d1ac4731c269f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,3 +1,5 @@ +import logging +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import func @@ -12,15 +14,15 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) + default_retrieval_model = { 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, @@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData node_type = NodeType.KNOWLEDGE_RETRIEVAL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + def _run(self) -> NodeRunResult: + node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = variable_pool.get_any(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable variables = { 'query': query @@ -68,7 +70,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) except Exception as e: - + logger.exception("Error when running knowledge retrieval node") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -235,11 +237,21 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: return context_list @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: KnowledgeRetrievalNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {} - variable_mapping['query'] = node_data.query_variable_selector + variable_mapping[node_id + '.query'] = node_data.query_variable_selector return variable_mapping def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 737b1af1434716..f26ec1b0b54269 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,16 +1,17 @@ import json -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast + +from pydantic import BaseModel from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, @@ -25,7 +26,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, @@ -43,17 +46,26 @@ +class ModelInvokeCompleted(BaseModel): + """ + Model invoke completed + """ + text: str + usage: LLMUsage + finish_reason: Optional[str] = None + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ node_data = cast(LLMNodeData, deepcopy(self.node_data)) + variable_pool = self.graph_runtime_state.variable_pool node_inputs = None process_data = None @@ -80,10 +92,15 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_inputs['#files#'] = [file.to_dict() for file in files] # fetch context value - context = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data, variable_pool) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield event if context: - node_inputs['#context#'] = context + node_inputs['#context#'] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -115,19 +132,34 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: } # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = '' + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) ) + return outputs = { 'text': result_text, @@ -135,22 +167,26 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: 'finish_reason': finish_reason } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + }, + llm_usage=usage + ) ) def _invoke_llm(self, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: list[str]) -> tuple[str, LLMUsage]: + stop: Optional[list[str]] = None) \ + -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -170,23 +206,31 @@ def _invoke_llm(self, node_data_model: ModelConfig, ) # handle invoke result - text, usage, finish_reason = self._handle_invoke_result( + generator = self._handle_invoke_result( invoke_result=invoke_result ) + usage = LLMUsage.empty_usage() + for event in generator: + yield event + if isinstance(event, ModelInvokeCompleted): + usage = event.usage + # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage, finish_reason - - def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ + -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result :return: """ + if isinstance(invoke_result, LLMResult): + return + model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = '' usage = None finish_reason = None @@ -194,7 +238,10 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage text = result.delta.message.content full_text += text - self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) + yield RunStreamChunkEvent( + chunk_content=text, + from_variable_selector=[self.node_id, 'text'] + ) if not model: model = result.model @@ -211,11 +258,15 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage if not usage: usage = LLMUsage.empty_usage() - return full_text, usage, finish_reason + yield ModelInvokeCompleted( + text=full_text, + usage=usage, + finish_reason=finish_reason + ) def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -224,13 +275,13 @@ def _transform_chat_messages(self, """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2': + if messages.edition_type == 'jinja2' and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2': + if message.edition_type == 'jinja2' and message.jinja2_text: message.text = message.jinja2_text return messages @@ -348,7 +399,7 @@ def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> l return files - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: """ Fetch context :param node_data: node data @@ -356,15 +407,18 @@ def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> :return: """ if not node_data.context.enabled: - return None + return if not node_data.context.variable_selector: - return None + return context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - return context_value + yield RunRetrieverResourceEvent( + retriever_resources=[], + context=context_value + ) elif isinstance(context_value, list): context_str = '' original_retriever_resource = [] @@ -381,17 +435,10 @@ def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> if retriever_resource: original_retriever_resource.append(retriever_resource) - if self.callbacks and original_retriever_resource: - for callback in self.callbacks: - callback.on_event( - event=QueueRetrieverResourcesEvent( - retriever_resources=original_retriever_resource - ) - ) - - return context_str.strip() - - return None + yield RunRetrieverResourceEvent( + retriever_resources=original_retriever_resource, + context=context_str.strip() + ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: """ @@ -574,7 +621,8 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent): + if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( + content_item, ImagePromptMessageContent): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -646,13 +694,19 @@ def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: db.session.commit() @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - prompt_template = node_data.prompt_template variable_selectors = [] @@ -702,6 +756,10 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 7d53c6f5f2c32e..526404e30d0d37 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,20 +1,34 @@ -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode +from typing import Any + +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.loop.entities import LoopNodeData, LoopState +from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseIterationNode): +class LoopNode(BaseNode): """ Loop Node. """ _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self, variable_pool: VariablePool) -> LoopState: - return super()._run(variable_pool) + def _run(self) -> LoopState: + return super()._run() - def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ - Get next iteration start node id based on the graph. + Get conditions. """ + node_id = node_config.get('id') + if not node_id: + return [] + + # TODO waiting for implementation + return [Condition( + variable_selector=[node_id, 'index'], + comparison_operator="≤", + value_type="value_selector", + value_selector=[] + )] diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py new file mode 100644 index 00000000000000..b98525e86e91f1 --- /dev/null +++ b/api/core/workflow/nodes/node_mapping.py @@ -0,0 +1,37 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode + +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.ITERATION_START: IterationStartNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, +} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2876695a825ba0..2e65705f10f2ce 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,6 +1,7 @@ import json import uuid -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -66,12 +67,12 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = variable_pool.get_any(node_data.query) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) if not variable: raise ValueError("Input variable content not found or is empty") query = variable @@ -92,17 +93,20 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ and node_data.reasoning_mode == 'function_call': # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, variable_pool, model_config, memory + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, + prompt_messages = self._generate_prompt_engineering_prompt(node_data, + query, + self.graph_runtime_state.variable_pool, + model_config, memory) prompt_message_tools = [] @@ -172,7 +176,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) def _invoke_llm(self, node_data_model: ModelConfig, @@ -697,15 +702,19 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ return self._model_instance, self._model_config @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - variable_mapping = { 'query': node_data.query } @@ -715,4 +724,8 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtr for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index f4057d50f38988..ecab8db9b62b56 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,12 @@ import json import logging -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder @@ -13,10 +15,9 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, @@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData node_type = NodeType.QUESTION_CLASSIFIER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) node_data = cast(QuestionClassifierNodeData, node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables variable = variable_pool.get(node_data.query_variable_selector) @@ -63,12 +65,23 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = '' + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + category_name = node_data.classes[0].name category_id = node_data.classes[0].id try: @@ -109,7 +122,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) except ValueError as e: @@ -121,13 +135,24 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {'query': node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: @@ -135,6 +160,11 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 54e66bd6718901..69cdec6a9260a3 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,9 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData +from collections.abc import Mapping, Sequence +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -11,14 +13,13 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData _node_type = NodeType.START - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_inputs = dict(variable_pool.user_inputs) - system_inputs = variable_pool.system_variables + node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + system_inputs = self.graph_runtime_state.variable_pool.system_variables for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] @@ -30,9 +31,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 21f71db6c549aa..b14a394a0a8cfd 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,15 +1,16 @@ import os -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) + class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -34,7 +35,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node """ @@ -45,7 +46,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: variables = {} for variable_selector in node_data.variables: variable_name = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable_name] = value # Run code try: @@ -60,7 +61,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: status=WorkflowNodeExecutionStatus.FAILED, error=str(e) ) - + if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, @@ -75,14 +76,21 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: 'output': result['result'] } ) - + @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } \ No newline at end of file + node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ccce9ef3607624..feedeb6dad3824 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -26,7 +26,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the tool node """ @@ -56,8 +56,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) + parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) try: messages = ToolEngine.workflow_invoke( @@ -66,6 +66,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, ) except Exception as e: return NodeRunResult( @@ -145,7 +146,8 @@ def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): + def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ + -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -221,9 +223,16 @@ def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -239,4 +248,8 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) elif input.type == 'constant': pass + result = { + node_id + '.' + key: value for key, value in result.items() + } + return result diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 885f7d76170f94..6944d9e82dcd41 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} @@ -20,7 +19,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: for selector in node_data.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs = { "output": variable @@ -33,7 +32,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: else: for group in node_data.advanced_settings.groups: for selector in group.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs[group.group_name] = { @@ -49,5 +48,17 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ return {} diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 8c2adcabb9dc05..b2f32c6aaaea20 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -6,7 +6,6 @@ from core.app.segments import SegmentType, Variable, factory from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from extensions.ext_database import db from models import ConversationVariable, WorkflowNodeExecutionStatus @@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: data = cast(VariableAssignerData, self.node_data) # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableAssignerNodeError('assigned variable not found') match data.write_mode: case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_variable = original_variable.model_copy(update={'value': income_value.value}) case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_value = original_variable.value + [income_value.value] @@ -49,11 +48,11 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) if not conversation_id: raise VariableAssignerNodeError('conversation_id not found') update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) diff --git a/api/core/workflow/utils/condition/__init__.py b/api/core/workflow/utils/condition/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py new file mode 100644 index 00000000000000..e195730a31ebb5 --- /dev/null +++ b/api/core/workflow/utils/condition/entities.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py new file mode 100644 index 00000000000000..5ff61aab3d79ec --- /dev/null +++ b/api/core/workflow/utils/condition/processor.py @@ -0,0 +1,383 @@ +from collections.abc import Sequence +from typing import Any, Optional + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.condition.entities import Condition +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class ConditionProcessor: + def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): + input_conditions = [] + group_result = [] + + index = 0 + for condition in conditions: + index += 1 + actual_value = variable_pool.get_any( + condition.variable_selector + ) + + expected_value = None + if condition.value is not None: + variable_template_parser = VariableTemplateParser(template=condition.value) + variable_selectors = variable_template_parser.extract_variable_selectors() + if variable_selectors: + for variable_selector in variable_selectors: + value = variable_pool.get_any( + variable_selector.value_selector + ) + expected_value = variable_template_parser.format({variable_selector.variable: value}) + + if expected_value is None: + expected_value = condition.value + else: + expected_value = condition.value + + comparison_operator = condition.comparison_operator + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": comparison_operator + } + ) + + result = self.evaluate_condition(actual_value, comparison_operator, expected_value) + group_result.append(result) + + return input_conditions, group_result + + def evaluate_condition( + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None + ) -> bool: + """ + Evaluate condition + :param actual_value: actual value + :param expected_value: expected value + :param comparison_operator: comparison operator + + :return: bool + """ + if comparison_operator == "contains": + return self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + return self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + return self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + return self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + return self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + return self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + return self._assert_empty(actual_value) + elif comparison_operator == "not empty": + return self._assert_not_empty(actual_value) + elif comparison_operator == "=": + return self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + return self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + return self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + return self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + return self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + return self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + return self._assert_null(actual_value) + elif comparison_operator == "not null": + return self._assert_not_null(actual_value) + else: + raise ValueError(f"Invalid comparison operator: {comparison_operator}") + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + +class ConditionAssertionError(Exception): + def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: + self.message = message + self.conditions = conditions + self.sub_condition_compare_results = sub_condition_compare_results + super().__init__(self.message) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3157eedfee5238..e69de29bb2d1d6 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,1005 +0,0 @@ -import logging -import time -from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast - -import contexts -from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.entities import IterationState -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode -from core.workflow.nodes.variable_assigner import VariableAssignerNode -from extensions.ext_database import db -from models.workflow import ( - Workflow, - WorkflowNodeExecutionStatus, -) - -node_classes: Mapping[NodeType, type[BaseNode]] = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, - NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, -} - -logger = logging.getLogger(__name__) - - -class WorkflowEngineManager: - def get_default_configs(self) -> list[dict]: - """ - Get default block configs - """ - default_block_configs = [] - for node_type, node_class in node_classes.items(): - default_config = node_class.get_default_config() - if default_config: - default_block_configs.append(default_config) - - return default_block_configs - - def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: - """ - Get default config of node. - :param node_type: node type - :param filters: filter by node config parameters. - :return: - """ - node_class = node_classes.get(node_type) - if not node_class: - return None - - default_config = node_class.get_default_config(filters=filters) - if not default_config: - return None - - return default_config - - def run_workflow( - self, - *, - workflow: Workflow, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - callbacks: Sequence[WorkflowCallback], - call_depth: int = 0, - variable_pool: VariablePool | None = None, - ) -> None: - """ - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param invoke_from: invoke from - :param callbacks: workflow callbacks - :param call_depth: call depth - :param variable_pool: variable pool - """ - # fetch workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - if 'nodes' not in graph or 'edges' not in graph: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - - - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH - if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - - # init workflow run state - if not variable_pool: - variable_pool = contexts.workflow_variable_pool.get() - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - workflow_call_depth=call_depth - ) - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - ) - - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: - """ - Run workflow - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :param call_depth: call depth - :param start_at: force specific start node - :param end_at: force specific end node - :return: - """ - graph = workflow.graph_dict - - try: - answer_prov_node_ids = [] - for node in graph.get('nodes', []): - if node.get('id', '') == 'answer': - try: - answer_prov_node_ids.append(node.get('data', {}) - .get('answer', '') - .replace('#', '') - .replace('.text', '') - .replace('{{', '') - .replace('}}', '').split('.')[0]) - except Exception as e: - logger.error(e) - - predecessor_node: BaseNode | None = None - current_iteration_node: BaseIterationNode | None = None - has_entry_node = False - max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS - max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME - while True: - # get next node, multiple target nodes in the future - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=predecessor_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - - if not next_node: - # reached loop/iteration end or overall end - if current_iteration_node and workflow_run_state.current_iteration_state: - # reached loop/iteration end - # get next iteration - next_iteration = current_iteration_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_iteration, NodeRunResult): - if next_iteration.outputs: - for variable_key, variable_value in next_iteration.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=current_iteration_node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - # iteration has ended - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=current_iteration_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - # continue overall process - elif isinstance(next_iteration, str): - # move to next iteration - next_node_id = next_iteration - # get next id - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if not next_node: - break - - # check is already ran - if self._check_node_has_ran(workflow_run_state, next_node.node_id): - predecessor_node = next_node - continue - - has_entry_node = True - - # max steps reached - if workflow_run_state.workflow_node_steps > max_execution_steps: - raise ValueError('Max steps {} reached.'.format(max_execution_steps)) - - # or max execution time reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) - - # handle iteration nodes - if isinstance(next_node, BaseIterationNode): - current_iteration_node = next_node - workflow_run_state.current_iteration_state = next_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - callbacks=callbacks - ) - predecessor_node = next_node - # move to start node of iteration - next_node_id = next_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - continue - else: - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if next_node and next_node.node_id in answer_prov_node_ids: - next_node.is_answer_previous_node = True - - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=next_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) - - if next_node.node_type in [NodeType.END]: - break - - predecessor_node = next_node - - if not has_entry_node: - self._workflow_run_failed( - error='Start node not found in workflow graph.', - callbacks=callbacks - ) - return - except GenerateTaskStoppedException as e: - return - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - return - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - def single_step_run_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: - """ - Single step run workflow node - :param workflow: Workflow instance - :param node_id: node id - :param user_id: user id - :param user_inputs: user inputs - :return: - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - # fetch node config from node id - node_config = None - for node in nodes: - if node.get('id') == node_id: - node_config = node - break - - if not node_config: - raise ValueError('node id not found in workflow graph') - - # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) - - # init workflow run state - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - workflow_call_depth=0 - ) - - try: - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - if node_cls is None: - raise ValueError('Node class not found') - # variable selector to variable mapping - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # run node - node_run_result = node_instance.run( - variable_pool=variable_pool - ) - - # sign output files - node_run_result.outputs = self.handle_special_values(node_run_result.outputs) - except Exception as e: - raise WorkflowNodeRunFailedError( - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_title=node_instance.node_data.title, - error=str(e) - ) - - return node_instance, node_run_result - - def single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[WorkflowCallback], - ) -> None: - """ - Single iteration run workflow node - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - for node in nodes: - if node.get('id') == node_id: - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]: - node_config = node - else: - raise ValueError('node id is not an iteration node') - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - # variable selector to variable mapping - iteration_nested_nodes = [ - node for node in nodes - if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id - ] - iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - for node_config in iteration_nested_nodes: - # mapping user inputs to variable pool - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - if node_cls is None: - raise ValueError('Node class not found') - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - # remove iteration variables - variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() - if value[0] != node_id - } - - # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_nested_node_ids - } - - # append variables to variable pool - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - callbacks=callbacks, - workflow_call_depth=0 - ) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # fetch end node of iteration - end_node_id = None - for edge in graph.get('edges'): - if edge.get('source') == node_id: - end_node_id = edge.get('target') - break - - if not end_node_id: - raise ValueError('end node of iteration not found') - - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - workflow_call_depth=0 - ) - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - start_at=node_id, - end_at=end_node_id - ) - - def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run success - :param callbacks: workflow callbacks - :return: - """ - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_succeeded() - - def _workflow_run_failed(self, error: str, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run failed - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - if callbacks: - for callback in callbacks: - callback.on_workflow_run_failed( - error=error - ) - - def _workflow_iteration_started(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - predecessor_node_id: Optional[str] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration started - :param current_iteration_node: current iteration node - :param workflow_run_state: workflow run state - :param callbacks: workflow callbacks - :return: - """ - # get nested nodes - iteration_nested_nodes = [ - node for node in graph.get('nodes') - if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id - ] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_started( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - node_data=current_iteration_node.node_data, - inputs=workflow_run_state.current_iteration_state.inputs, - predecessor_node_id=predecessor_node_id, - metadata=workflow_run_state.current_iteration_state.metadata.model_dump() - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - def _workflow_iteration_next(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration next - :param workflow_run_state: workflow run state - :return: - """ - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_next( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - index=workflow_run_state.current_iteration_state.index, - node_run_index=workflow_run_state.workflow_node_steps, - output=workflow_run_state.current_iteration_state.get_current_output() - ) - # clear ran nodes - workflow_run_state.workflow_node_runs = [ - node_run for node_run in workflow_run_state.workflow_node_runs - if node_run.iteration_node_id != current_iteration_node.node_id - ] - - # clear variables in current iteration - nodes = graph.get('nodes') - nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] - - for node in nodes: - workflow_run_state.variable_pool.remove((node.get('id'),)) - - def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_completed( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - outputs={ - 'output': workflow_run_state.current_iteration_state.outputs - } - ) - - def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> Optional[BaseNode]: - """ - Get next node - multiple target nodes in the future. - :param graph: workflow graph - :param predecessor_node: predecessor node - :param callbacks: workflow callbacks - :return: - """ - nodes = graph.get('nodes') - if not nodes: - return None - - if not predecessor_node: - for node_config in nodes: - node_cls = None - if start_at: - if node_config.get('id') == start_at: - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - else: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - node_cls = StartNode - if node_cls: - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - else: - edges = graph.get('edges') - source_node_id = predecessor_node.node_id - - # fetch all outgoing edges from source node - outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] - if not outgoing_edges: - return None - - # fetch target node id from outgoing edges - outgoing_edge = None - source_handle = predecessor_node.node_run_result.edge_source_handle \ - if predecessor_node.node_run_result else None - if source_handle: - for edge in outgoing_edges: - if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - outgoing_edge = edge - break - else: - outgoing_edge = outgoing_edges[0] - - if not outgoing_edge: - return None - - target_node_id = outgoing_edge.get('target') - - if end_at and target_node_id == end_at: - return None - - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break - - if not target_node_config: - return None - - # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) - - return target_node( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - node_id: str, - callbacks: Sequence[WorkflowCallback]): - """ - Get node from graph by node id - """ - nodes = graph.get('nodes') - if not nodes: - return None - - for node_config in nodes: - if node_config.get('id') == node_id: - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes[node_type] - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: - """ - Check timeout - :param start_at: start time - :param max_execution_time: max execution time - :return: - """ - return time.perf_counter() - start_at > max_execution_time - - def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: - """ - Check node has ran - """ - return bool([ - node_and_result for node_and_result in workflow_run_state.workflow_node_runs - if node_and_result.node_id == node_id - ]) - - def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - node_run_index=workflow_run_state.workflow_node_steps, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None - ) - - db.session.close() - - workflow_nodes_and_result = WorkflowNodeAndResult( - node=node, - result=None - ) - - # add to workflow_nodes_and_results - workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - # mark node as running - if workflow_run_state.current_iteration_state: - workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( - node_id=node.node_id, - iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id - )) - - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) - except GenerateTaskStoppedException as e: - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error='Workflow stopped.' - ) - except Exception as e: - logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - # node run failed - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_failed( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - error=node_run_result.error, - inputs=node_run_result.inputs, - outputs=node_run_result.outputs, - process_data=node_run_result.process_data, - ) - - raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - - if node.is_answer_previous_node and not isinstance(node, LLMNode): - if not node_run_result.metadata: - node_run_result.metadata = {} - node_run_result.metadata["is_answer_previous_node"]=True - workflow_nodes_and_result.result = node_run_result - - # node run success - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_succeeded( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - inputs=node_run_result.inputs, - process_data=node_run_result.process_data, - outputs=node_run_result.outputs, - execution_metadata=node_run_result.metadata - ) - - if node_run_result.outputs: - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - - if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - db.session.close() - - def _append_variables_recursively(self, variable_pool: VariablePool, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): - """ - Append variables recursively - :param variable_pool: variable pool - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add( - [node_id] + variable_key_list, variable_value - ) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value - ) - - @classmethod - def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None - - new_value = value.copy() - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value - - def _mapping_user_inputs_to_variable_pool(self, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, - variable_pool: VariablePool, - tenant_id: str, - node_instance: BaseNode): - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # FIXME: temp fix for image type - if node_instance.node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.add([variable_node_id]+variable_key_list, value) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py new file mode 100644 index 00000000000000..a359bd606ea2e2 --- /dev/null +++ b/api/core/workflow/workflow_entry.py @@ -0,0 +1,314 @@ +import logging +import time +import uuid +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional, cast + +from configs import dify_config +from core.app.app_config.entities import FileExtraConfig +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunEvent +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.node_mapping import node_classes +from models.workflow import ( + Workflow, + WorkflowType, +) + +logger = logging.getLogger(__name__) + + +class WorkflowEntry: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None + ) -> None: + """ + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph + :param user_id: user id + :param user_from: user from + :param invoke_from: invoke from + :param call_depth: call depth + :param variable_pool: variable pool + :param thread_pool_id: thread pool id + """ + # check call depth + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + + # init workflow run state + self.graph_engine = GraphEngine( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=thread_pool_id + ) + + def run( + self, + *, + callbacks: Sequence[WorkflowCallback], + ) -> Generator[GraphEngineEvent, None, None]: + """ + :param callbacks: workflow callbacks + """ + graph_engine = self.graph_engine + + try: + # run workflow + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event( + event=event + ) + yield event + except GenerateTaskStoppedException: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + if callbacks: + for callback in callbacks: + callback.on_event( + event=GraphRunFailedEvent( + error=str(e) + ) + ) + return + + @classmethod + def single_step_run( + cls, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict + ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get('id') == node_id: + node_config = node + break + + if not node_config: + raise ValueError('node id not found in workflow graph') + + # Get node class + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + if not node_cls: + raise ValueError(f'Node class not found for node type {node_type}') + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + # init graph + graph = Graph.init( + graph_config=workflow.graph_dict + ) + + # init workflow run state + node_instance: BaseNode = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0 + ), + graph=graph, + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter() + ) + ) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=node_instance.node_data + ) + + # run node + generator = node_instance.run() + + return node_instance, generator + except Exception as e: + raise WorkflowNodeRunFailedError( + node_instance=node_instance, + error=str(e) + ) + + @classmethod + def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: + """ + Handle special values + :param value: value + :return: + """ + if not value: + return None + + new_value = dict(value) if value else {} + if isinstance(new_value, dict): + for key, val in new_value.items(): + if isinstance(val, FileVar): + new_value[key] = val.to_dict() + elif isinstance(val, list): + new_val = [] + for v in val: + if isinstance(v, FileVar): + new_val.append(v.to_dict()) + else: + new_val.append(v) + + new_value[key] = new_val + + return new_value + + @classmethod + def mapping_user_inputs_to_variable_pool( + cls, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_type: NodeType, + node_data: BaseNodeData + ) -> None: + for node_variable, variable_selector in variable_mapping.items(): + # fetch node id and variable key from node_variable + node_variable_list = node_variable.split('.') + if len(node_variable_list) < 1: + raise ValueError(f'Invalid node variable {node_variable}') + + node_variable_key = '.'.join(node_variable_list[1:]) + + if ( + node_variable_key not in user_inputs + and node_variable not in user_inputs + ) and not variable_pool.get(variable_selector): + raise ValueError(f'Variable key {node_variable} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + variable_key_list = cast(list[str], variable_key_list) + + # get input value + input_value = user_inputs.get(node_variable) + if not input_value: + input_value = user_inputs.get(node_variable_key) + + # FIXME: temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(input_value, list): + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in input_value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + + # append variable and value to variable pool + variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 00000000000000..55824945da49b0 --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: 030f4915f36a +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = '030f4915f36a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index cdd5e1992ded49..e78b5666bcfcdf 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model): 'triggered_from', 'workflow_run_id'), db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'), + db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_execution_id'), ) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) @@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model): workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) + node_execution_id = db.Column(db.String(255), nullable=True) node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 895855a9c86bf3..73c446b83b7e79 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,8 +13,9 @@ logger = logging.getLogger(__name__) -current_dsl_version = "0.1.1" +current_dsl_version = "0.1.2" dsl_to_dify_version_mapping: dict[str, str] = { + "0.1.2": "0.8.0", "0.1.1": "0.6.0", # dsl version -> from dify version } diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 747505977fe282..26517a05fb9873 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService @@ -103,9 +104,7 @@ def _get_max_active_requests(app_model: App) -> int: return max_active_requests @classmethod - def generate_single_iteration( - cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True - ): + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( @@ -142,7 +141,7 @@ def generate_more_like_this( ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: """ Get workflow :param app_model: app model diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4c3ded14ade386..357ffd41c127a5 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -8,9 +8,11 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account @@ -172,8 +174,13 @@ def get_default_block_configs(self) -> list[dict]: Get default block configs """ # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_configs() + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: """ @@ -182,11 +189,18 @@ def get_default_block_config(self, node_type: str, filters: Optional[dict] = Non :param filters: filter by node config parameters. :return: """ - node_type = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType.value_of(node_type) # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_config(node_type, filters) + node_class = node_classes.get(node_type_enum) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config def run_draft_workflow_node( self, app_model: App, node_id: str, user_inputs: dict, account: Account @@ -200,82 +214,68 @@ def run_draft_workflow_node( raise ValueError("Workflow not initialized") # run draft workflow node - workflow_engine_manager = WorkflowEngineManager() start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + node_instance, generator = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ) - except WorkflowNodeRunFailedError as e: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=e.node_id, - node_type=e.node_type.value, - title=e.node_title, - status=WorkflowNodeExecutionStatus.FAILED.value, - error=e.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) - db.session.add(workflow_node_execution) - db.session.commit() - return workflow_node_execution + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") - if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + error = node_run_result.error if not run_succeeded else None + except WorkflowNodeRunFailedError as e: + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = app_model.tenant_id + workflow_node_execution.app_id = app_model.id + workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type.value + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by = account.id + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), + workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None + workflow_node_execution.process_data = ( + json.dumps(node_run_result.process_data) if node_run_result.process_data else None + ) + workflow_node_execution.outputs = ( + json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None + ) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value else: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - status=node_run_result.status.value, - error=node_run_result.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error db.session.add(workflow_node_execution) db.session.commit() @@ -321,25 +321,3 @@ def validate_features_structure(self, app_model: App, features: dict) -> dict: ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") - - @classmethod - def get_elapsed_time(cls, workflow_run_id: str) -> float: - """ - Get elapsed time - """ - elapsed_time = 0.0 - - # fetch workflow node execution by workflow_run_id - workflow_nodes = ( - db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) - .order_by(WorkflowNodeExecution.created_at.asc()) - .all() - ) - if not workflow_nodes: - return elapsed_time - - for node in workflow_nodes: - elapsed_time += node.elapsed_time - - return elapsed_time diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 6f5421e108f0c7..952c90674da7d2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,17 +1,72 @@ +import time +import uuid from os import getenv +from typing import cast import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.nodes.code.entities import CodeNodeData +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) +def init_code_node(code_config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["code", "123", "args1"], 1) + variable_pool.add(["code", "123", "args2"], 2) + + node = CodeNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=code_config, + ) + + return node + + @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """ @@ -22,44 +77,36 @@ def main(args1: int, args2: int) -> dict: """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "number", - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + node = init_code_node(code_config) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["result"] == 3 assert result.error is None @@ -74,44 +121,34 @@ def main(args1: int, args2: int) -> dict: """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "string", - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "string", }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + node = init_code_node(code_config) # execute node - result = node.run(pool) - + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error == "Output variable `result` must be a string" @@ -127,65 +164,60 @@ def main(args1: int, args2: int) -> dict: """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "string_validator": { - "type": "string", - }, - "number_validator": { - "type": "number", - }, - "number_array_validator": { - "type": "array[number]", - }, - "string_array_validator": { - "type": "array[string]", - }, - "object_validator": { - "type": "object", - "children": { - "result": { - "type": "number", - }, - "depth": { - "type": "object", - "children": { - "depth": { - "type": "object", - "children": { - "depth": { - "type": "number", - } - }, - } - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + }, + } }, }, }, }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + node = init_code_node(code_config) # construct result result = { @@ -196,6 +228,8 @@ def main(args1: int, args2: int) -> dict: "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) @@ -250,35 +284,30 @@ def main(args1: int, args2: int) -> dict: """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, - config={ - "id": "1", - "data": { - "outputs": { - "object_list": { - "type": "array[object]", - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "object_list": { + "type": "array[object]", }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + node = init_code_node(code_config) # construct result result = { @@ -295,6 +324,8 @@ def main(args1: int, args2: int) -> dict: ] } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index acb616b3256ac3..65aaa0bddd5e95 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,31 +1,69 @@ +import time +import uuid from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock -BASIC_NODE_DATA = { - "tenant_id": "1", - "app_id": "1", - "workflow_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.WEB_APP, -} -# construct variable pool -pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(["a", "b123", "args1"], 1) -pool.add(["a", "b123", "args2"], 2) +def init_http_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return HttpRequestNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -45,12 +83,11 @@ def test_get(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -59,7 +96,7 @@ def test_get(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_no_auth(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -143,11 +178,11 @@ def test_template(setup_http_mock): "params": "A:b\nTemplate:{{#a.b123.args2#}}", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -158,7 +193,7 @@ def test_template(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -178,11 +213,11 @@ def test_json(setup_http_mock): "params": "A:b", "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert '{"a": "1"}' in data @@ -190,7 +225,7 @@ def test_json(setup_http_mock): def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock): "params": "A:b", "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "a=1&b=2" in data @@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock): def test_form_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -242,11 +277,11 @@ def test_form_data(setup_http_mock): "params": "A:b", "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert 'form-data; name="a"' in data @@ -257,7 +292,7 @@ def test_form_data(setup_http_mock): def test_none_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -277,11 +312,11 @@ def test_none_data(setup_http_mock): "params": "A:b", "body": {"type": "none", "data": "123123123"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "X-Header: 123" in data @@ -289,7 +324,7 @@ def test_none_data(setup_http_mock): def test_mock_404(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock): "params": "", "headers": "X-Header:123", }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.outputs is not None resp = result.outputs assert 404 == resp.get("status_code") - assert "Not Found" in resp.get("body") + assert "Not Found" in resp.get("body", "") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock): "headers": "Referer:http://example3.com\nRedirect:http://example4.com", "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None + assert result.outputs is not None resp = result.outputs - assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") - assert "http://example3.com" == resp.get("headers").get("referer") + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "") + assert "http://example3.com" == resp.get("headers", {}).get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6bab83a0191bca..dfb43650d20c1c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,5 +1,8 @@ import json import os +import time +import uuid +from collections.abc import Generator from unittest.mock import MagicMock import pytest @@ -10,28 +13,77 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db from models.provider import ProviderType -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_execute_llm(setup_openai_mock): - node = LLMNode( +def init_llm_node(config: dict) -> LLMNode: + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["abc", "output"], "sunny") + + node = LLMNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + return node + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_execute_llm(setup_openai_mock): + node = init_llm_node( config={ "id": "llm", "data": { @@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock): model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, Generator) - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["text"] is not None - assert result.outputs["usage"]["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ - node = LLMNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_llm_node( config={ "id": "llm", "data": { @@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") - + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) - - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "sunny" in json.dumps(result.process_data) - assert "what's the weather today?" in json.dumps(result.process_data) + result = node._run() + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ca2bae5c536090..cbe9c5914f1f0d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,5 +1,7 @@ import json import os +import time +import uuid from typing import Optional from unittest.mock import MagicMock @@ -8,19 +10,21 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db from models.provider import ProviderType """FOR MOCK FIXTURES, DO NOT REMOVE""" -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock @@ -47,13 +51,15 @@ def get_mocked_fetch_model_config( model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) + model_schema = model_type_instance.get_model_schema(model) + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model=model, provider=provider, mode=mode, credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema(model), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -74,18 +80,62 @@ def get_history_prompt_text( return MagicMock(return_value=MemoryMock()) -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_function_calling_parameter_extractor(setup_openai_mock): - """ - Test function calling for parameter extractor. - """ - node = ParameterExtractorNode( +def init_parameter_extractor_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return ParameterExtractorNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_function_calling_parameter_extractor(setup_openai_mock): + """ + Test function calling for parameter extractor. + """ + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock): "reasoning_mode": "function_call", "memory": None, }, - }, + } ) node._fetch_model_config = get_mocked_fetch_model_config( @@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): environment_variables=[], ) - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None @@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None process_data = result.process_data + assert process_data is not None process_data.get("prompts") - for prompt in process_data.get("prompts"): + for prompt in process_data.get("prompts", []): if prompt.get("role") == "system": assert "what's the weather in SF" in prompt.get("text") @@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) for prompt in prompts: if prompt.get("role") == "user": @@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - assert len(result.process_data.get("prompts")) == 1 - assert "SF" in result.process_data.get("prompts")[0].get("text") + assert result.process_data is not None + assert len(result.process_data.get("prompts", [])) == 1 + assert "SF" in result.process_data.get("prompts", [])[0].get("text") def test_extract_json_response(): @@ -322,13 +325,7 @@ def test_extract_json_response(): Test extract json response. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -357,6 +354,7 @@ def test_extract_json_response(): hello world. """) + assert result is not None assert result["location"] == "kawaii" @@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) latest_role = None for prompt in prompts: diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 617b6370c9f410..073c4bb7996d33 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,46 +1,84 @@ +import time +import uuid + import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """{{args2}}""" - node = TemplateTransformNode( + config = { + "id": "1", + "data": { + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "template": code, + }, + } + + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.END_USER, - config={ - "id": "1", - "data": { - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "template": code, - }, - }, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 3) + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["1", "123", "args1"], 1) + variable_pool.add(["1", "123", "args2"], 3) + + node = TemplateTransformNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) # execute node - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 29c1efa8e78b2a..4d94cdb28af4cd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,21 +1,62 @@ +import time +import uuid + from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.tool.tool_node import ToolNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType -def test_tool_variable_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], "1+1") +def init_tool_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } - node = ToolNode( + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + return ToolNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + +def test_tool_variable_invoke(): + node = init_tool_node( config={ "id": "1", "data": { @@ -34,28 +75,22 @@ def test_tool_variable_invoke(): } }, }, - }, + } ) - # execute node - result = node.run(pool) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1") + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert "2" in result.outputs["text"] assert result.outputs["files"] == [] def test_tool_mixed_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "args1"], "1+1") - - node = ToolNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_tool_node( config={ "id": "1", "data": { @@ -74,12 +109,15 @@ def test_tool_mixed_invoke(): } }, }, - }, + } ) - # execute node - result = node.run(pool) + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert "2" in result.outputs["text"] assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index afc9802cf1cbe7..ca3082953aef5c 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -1,7 +1,24 @@ import os +import pytest +from flask import Flask + # Getting the absolute path of the current file's directory ABS_PATH = os.path.dirname(os.path.abspath(__file__)) # Getting the absolute path of the project's root directory PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) + +CACHED_APP = Flask(__name__) +CACHED_APP.config.update({"TESTING": True}) + + +@pytest.fixture() +def app() -> Flask: + return CACHED_APP + + +@pytest.fixture(autouse=True) +def _provide_app_context(app: Flask): + with app.app_context(): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py new file mode 100644 index 00000000000000..65757cd604cfa8 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -0,0 +1,791 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.utils.condition.entities import Condition + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "start-source-qc-target", + "source": "start", + "target": "qc", + }, + { + "id": "qc-1-llm-target", + "source": "qc", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "qc-2-http-target", + "source": "qc", + "sourceHandle": "2", + "target": "http", + }, + { + "id": "http-source-answer2-target", + "source": "http", + "target": "answer2", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "question-classifier"}, + "id": "qc", + }, + { + "data": { + "type": "http-request", + }, + "id": "http", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + start_node_id = "start" + + assert graph.root_node_id == start_node_id + assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" + assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} + + +def test__init_iteration_graph(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "iteration"}, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + }, + ], + } + + graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration") + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")], + ), + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" + + +def test_parallels_graph(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + start_edges = graph.edge_mapping.get("start") + assert start_edges is not None + assert start_edges[i].target_node_id == f"llm{i+1}" + + llm_edges = graph.edge_mapping.get(f"llm{i+1}") + assert llm_edges is not None + assert llm_edges[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph2(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + if i < 2: + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph3(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph4(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "code2", + }, + { + "id": "llm3-source-code3-target", + "source": "llm3", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" + assert graph.edge_mapping.get(f"code{i + 1}") is not None + assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph5(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm4", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm5", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-code1-target", + "source": "llm2", + "target": "code1", + }, + { + "id": "llm3-source-code2-target", + "source": "llm3", + "target": "code2", + }, + { + "id": "llm4-source-code2-target", + "source": "llm4", + "target": "code2", + }, + { + "id": "llm5-source-code3-target", + "source": "llm5", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(5): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm3") is not None + assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm4") is not None + assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm5") is not None + assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 8 + + for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph6(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm1-source-code2-target", + "source": "llm1", + "target": "code2", + }, + { + "id": "llm2-source-code3-target", + "source": "llm2", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code3") is not None + assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 2 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + parent_parallel = None + child_parallel = None + for p_id, parallel in graph.parallel_mapping.items(): + if parallel.parent_parallel_id is None: + parent_parallel = parallel + else: + child_parallel = parallel + + for node_id in ["llm1", "llm2", "llm3", "code3"]: + assert graph.node_parallel_mapping[node_id] == parent_parallel.id + + for node_id in ["code1", "code2"]: + assert graph.node_parallel_mapping[node_id] == child_parallel.id diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py new file mode 100644 index 00000000000000..a2d71d61fcef52 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -0,0 +1,505 @@ +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + BaseNodeEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.llm_node import LLMNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_workflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "llm1", + }, + { + "id": "2", + "source": "llm1", + "target": "llm2", + }, + { + "id": "3", + "source": "llm1", + "target": "llm3", + }, + { + "id": "4", + "source": "llm2", + "target": "end1", + }, + { + "id": "5", + "source": "llm3", + "target": "end2", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + "title": "start", + "variables": [ + { + "label": "query", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "query", + } + ], + }, + "id": "start", + }, + { + "data": { + "type": "llm", + "title": "llm1", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say hi"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + "title": "llm2", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say bye"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + "title": "llm3", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say good morning"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm3", + }, + { + "data": { + "type": "end", + "title": "end1", + "outputs": [ + {"value_selector": ["llm2", "text"], "variable": "result2"}, + {"value_selector": ["start", "query"], "variable": "query"}, + ], + }, + "id": "end1", + }, + { + "data": { + "type": "end", + "title": "end2", + "outputs": [ + {"value_selector": ["llm1", "text"], "variable": "result1"}, + {"value_selector": ["llm3", "text"], "variable": "result3"}, + ], + }, + "id": "end2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent( + chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"] + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + # print("") + + with patch.object(LLMNode, "_run", new=llm_generator): + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + assert item.parallel_id is not None + + assert len(items) == 18 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_chatflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "answer1", + }, + { + "id": "2", + "source": "answer1", + "target": "answer2", + }, + { + "id": "3", + "source": "answer1", + "target": "answer3", + }, + { + "id": "4", + "source": "answer2", + "target": "answer4", + }, + { + "id": "5", + "source": "answer3", + "target": "answer5", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "start"}, "id": "start"}, + {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"}, + { + "data": {"type": "answer", "title": "answer2", "answer": "2"}, + "id": "answer2", + }, + { + "data": {"type": "answer", "title": "answer3", "answer": "3"}, + "id": "answer3", + }, + { + "data": {"type": "answer", "title": "answer4", "answer": "4"}, + "id": "answer4", + }, + { + "data": {"type": "answer", "title": "answer5", "answer": "5"}, + "id": "answer5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + "answer2", + "answer3", + "answer4", + "answer5", + ]: + assert item.parallel_id is not None + + assert len(items) == 23 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_branch(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "if-else-1", + }, + { + "id": "2", + "source": "if-else-1", + "sourceHandle": "true", + "target": "answer-1", + }, + { + "id": "3", + "source": "if-else-1", + "sourceHandle": "false", + "target": "if-else-2", + }, + { + "id": "4", + "source": "if-else-2", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "5", + "source": "if-else-2", + "sourceHandle": "false", + "target": "answer-3", + }, + ], + "nodes": [ + { + "data": { + "title": "Start", + "type": "start", + "variables": [ + { + "label": "uid", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "uid", + } + ], + }, + "id": "start", + }, + { + "data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []}, + "id": "answer-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", + "value": "hi", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "desc": "", + "title": "IF/ELSE", + "type": "if-else", + }, + "id": "if-else-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "ae895199-5608-433b-b5f0-0997ae1431e4", + "value": "takatost", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "title": "IF/ELSE 2", + "type": "if-else", + }, + "id": "if-else-2", + }, + { + "data": { + "answer": "2", + "title": "Answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "answer": "3", + "title": "Answer 3", + "type": "answer", + }, + "id": "answer-3", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "hi", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + + assert len(items) == 10 + assert items[3].route_node_state.node_id == "if-else-1" + assert items[4].route_node_state.node_id == "if-else-1" + assert isinstance(items[5], NodeRunStreamChunkEvent) + assert items[5].chunk_content == "1 " + assert isinstance(items[6], NodeRunStreamChunkEvent) + assert items[6].chunk_content == "takato" + assert items[7].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato" + assert isinstance(items[9], GraphRunSucceededEvent) + + # print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py b/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py new file mode 100644 index 00000000000000..fe4ede6335fc63 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -0,0 +1,82 @@ +import time +import uuid +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.answer.answer_node import AnswerNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_execute_answer(): + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py new file mode 100644 index 00000000000000..bce87536d8e995 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py @@ -0,0 +1,109 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + answer_stream_generate_route = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping + ) + + assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"] + assert answer_stream_generate_route.answer_dependencies["answer2"] == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py new file mode 100644 index 00000000000000..6b1d1e9070a3a6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -0,0 +1,216 @@ +import uuid +from collections.abc import Generator +from datetime import datetime, timezone + +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.start.entities import StartNodeData + + +def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + if next_node_id == "start": + yield from _publish_events(graph, next_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _publish_events(graph, edge.target_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _recursive_process(graph, edge.target_node_id) + + +def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + + parallel_id = graph.node_parallel_mapping.get(next_node_id) + parallel_start_node_id = None + if parallel_id: + parallel = graph.parallel_mapping.get(parallel_id) + parallel_start_node_id = parallel.start_from_node_id if parallel else None + + node_execution_id = str(uuid.uuid4()) + node_config = graph.node_id_config_mapping[next_node_id] + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + mock_node_data = StartNodeData(**{"title": "demo", "variables": []}) + + yield NodeRunStartedEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_start_node_id=parallel_start_node_id, + ) + + if "llm" in next_node_id: + length = int(next_node_id[-1]) + for i in range(0, length): + yield NodeRunStreamChunkEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + chunk_content=str(i), + route_node_state=route_node_state, + from_variable_selector=[next_node_id, "text"], + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + route_node_state.status = RouteNodeState.Status.SUCCESS + route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + yield NodeRunSucceededEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + +def test_process(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool) + + def graph_generator() -> Generator[GraphEngineEvent, None, None]: + # print("") + for event in _recursive_process(graph, "start"): + # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunSucceededEvent): + if "llm" in event.route_node_state.node_id: + variable_pool.add( + [event.route_node_state.node_id, "text"], + "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))), + ) + yield event + + result_generator = answer_stream_processor.process(graph_generator()) + stream_contents = "" + for event in result_generator: + # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunStreamChunkEvent): + stream_contents += event.chunk_content + pass + + assert stream_contents == "c012da01b" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py new file mode 100644 index 00000000000000..b3a89061b29e21 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -0,0 +1,420 @@ +import time +import uuid +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_run(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 20 + + +def test_run_parallel(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 32 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 8020674ee6e3d9..cb2e99a854dfcc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,22 +1,70 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_answer(): - node = AnswerNode( + graph_config = { + "edges": [ + { + "id": "start-source-answer-target", + "source": "start", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["start", "weather"], "sunny") + variable_pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "answer", "data": { @@ -27,20 +75,11 @@ def test_execute_answer(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "weather"], "sunny") - pool.add(["llm", "text"], "You are a helpful AI.") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 9535bc2186af72..0795f134d00d8a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,22 +1,63 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_if_else_result_true(): - node = IfElseNode( + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -63,48 +104,64 @@ def test_execute_if_else_result_true(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["ab", "def"]) - pool.add(["start", "array_not_contains"], ["ac", "def"]) - pool.add(["start", "contains"], "cabcde") - pool.add(["start", "not_contains"], "zacde") - pool.add(["start", "start_with"], "abc") - pool.add(["start", "end_with"], "zzab") - pool.add(["start", "is"], "ab") - pool.add(["start", "is_not"], "aab") - pool.add(["start", "empty"], "") - pool.add(["start", "not_empty"], "aaa") - pool.add(["start", "equals"], 22) - pool.add(["start", "not_equals"], 23) - pool.add(["start", "greater_than"], 23) - pool.add(["start", "less_than"], 21) - pool.add(["start", "greater_than_or_equal"], 22) - pool.add(["start", "less_than_or_equal"], 21) - pool.add(["start", "not_null"], "1212") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is True def test_execute_if_else_result_false(): - node = IfElseNode( + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -127,20 +184,11 @@ def test_execute_if_else_result_false(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["1ab", "def"]) - pool.add(["start", "array_not_contains"], ["ab", "def"]) - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index e26c7df642776b..f45a93f1be58ce 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -1,17 +1,56 @@ +import time +import uuid from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode +from models.workflow import WorkflowType DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = StringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -24,36 +63,36 @@ def test_overwrite_string_variable(): value="the second value", ) - node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config={ - "id": "node_id", - "data": { - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, - ) - + # construct variable pool variable_pool = VariablePool( system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], ) + variable_pool.add( [DEFAULT_NODE_ID, input_variable.name], input_variable, ) + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -63,6 +102,39 @@ def test_overwrite_string_variable(): def test_append_variable_to_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -75,23 +147,6 @@ def test_append_variable_to_array(): value="the second value", ) - node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config={ - "id": "node_id", - "data": { - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, - ) - variable_pool = VariablePool( system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, @@ -103,8 +158,23 @@ def test_append_variable_to_array(): input_variable, ) + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -113,19 +183,57 @@ def test_append_variable_to_array(): def test_clear_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", value=["the first value"], ) + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "node_id", "data": { @@ -136,14 +244,9 @@ def test_clear_array(): }, ) - variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - node.run(variable_pool) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index bc42d4b8f78374..7075a31f2b8e76 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -396,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.3 + image: langgenius/dify-web:0.8.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aca87a6f8f4227..68f897ddb92a6b 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -208,7 +208,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -228,7 +228,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -247,7 +247,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.3 + image: langgenius/dify-web:0.8.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 5f36e40c40fe54..1f17798f83c8e9 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -11,10 +11,10 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import type { ChatItem, WorkflowProcess } from '../../types' +import TracingPanel from '@/app/components/workflow/run/tracing-panel' import cn from '@/utils/classnames' import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general' import { WorkflowRunningStatus } from '@/app/components/workflow/types' -import NodePanel from '@/app/components/workflow/run/node' import { useStore as useAppStore } from '@/app/components/app/store' type WorkflowProcessProps = { @@ -107,16 +107,12 @@ const WorkflowProcessItem = ({ !collapse && (
{ - data.tracing.map(node => ( -
- -
- )) + }
) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index e6638b8eeda20b..892f88c4ad72bb 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -241,8 +241,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) hasStopResponded.current = false @@ -503,12 +501,13 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, onIterationFinish: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], + const iterationIndex = tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + tracing[iterationIndex] = { + ...tracing[iterationIndex], ...data, status: WorkflowRunningStatus.Succeeded, } as any @@ -520,10 +519,9 @@ export const useChat = ( ...responseItem, } })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return responseItem.workflowProcess!.tracing!.push({ @@ -539,10 +537,15 @@ export const useChat = ( })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return - const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id) + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata.parallel_id) + }) responseItem.workflowProcess!.tracing[currentIndex] = data as any handleUpdateChatList(produce(chatListRef.current, (draft) => { const currentIndex = draft.findIndex(item => item.id === responseItem.id) diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 021a98d4d3b31b..96fe9f01ef20ec 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -196,8 +196,6 @@ const Result: FC = ({ })() if (isWorkflow) { - let isInIteration = false - sendWorkflowMessage( data, { @@ -219,23 +217,28 @@ const Result: FC = ({ expand: true, } as any) })) - isInIteration = true }, onIterationNext: () => { + setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + draft.expand = true + const iterations = draft.tracing.find(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + iterations?.details!.push([]) + })) }, onIterationFinish: ({ data }) => { setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true - // const iteration = draft.tracing![draft.tracing!.length - 1] - draft.tracing![draft.tracing!.length - 1] = { + const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + draft.tracing[iterationsIndex] = { ...data, expand: !!data.error, } as any })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { @@ -248,11 +251,12 @@ const Result: FC = ({ })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) + const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id + && (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id)) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras diff --git a/web/app/components/workflow/candidate-node.tsx b/web/app/components/workflow/candidate-node.tsx index faae545b3b5dde..16d6f852b2fe26 100644 --- a/web/app/components/workflow/candidate-node.tsx +++ b/web/app/components/workflow/candidate-node.tsx @@ -14,9 +14,11 @@ import { } from './store' import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks' import { CUSTOM_NODE } from './constants' +import { getIterationStartNode } from './utils' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' +import { BlockEnum } from './types' const CandidateNode = () => { const store = useStoreApi() @@ -52,6 +54,8 @@ const CandidateNode = () => { y, }, }) + if (candidateNode.data.type === BlockEnum.Iteration) + draft.push(getIterationStartNode(candidateNode.id)) }) setNodes(newNodes) if (candidateNode.type === CUSTOM_NOTE_NODE) diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 070748bab0ec9c..6a4629e9c86b7e 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -15,6 +15,7 @@ import VariableAssignerDefault from './nodes/variable-assigner/default' import AssignerDefault from './nodes/assigner/default' import EndNodeDefault from './nodes/end/default' import IterationDefault from './nodes/iteration/default' +import IterationStartDefault from './nodes/iteration-start/default' type NodesExtraData = { author: string @@ -89,6 +90,15 @@ export const NODES_EXTRA_DATA: Record = { getAvailableNextNodes: IterationDefault.getAvailableNextNodes, checkValid: IterationDefault.checkValid, }, + [BlockEnum.IterationStart]: { + author: 'Dify', + about: '', + availablePrevNodes: [], + availableNextNodes: [], + getAvailablePrevNodes: IterationStartDefault.getAvailablePrevNodes, + getAvailableNextNodes: IterationStartDefault.getAvailableNextNodes, + checkValid: IterationStartDefault.checkValid, + }, [BlockEnum.Code]: { author: 'Dify', about: '', @@ -222,6 +232,12 @@ export const NODES_INITIAL_DATA = { desc: '', ...IterationDefault.defaultValue, }, + [BlockEnum.IterationStart]: { + type: BlockEnum.IterationStart, + title: '', + desc: '', + ...IterationStartDefault.defaultValue, + }, [BlockEnum.Code]: { type: BlockEnum.Code, title: '', @@ -305,11 +321,13 @@ export const AUTO_LAYOUT_OFFSET = { export const ITERATION_NODE_Z_INDEX = 1 export const ITERATION_CHILDREN_Z_INDEX = 1002 export const ITERATION_PADDING = { - top: 85, + top: 65, right: 16, bottom: 20, left: 16, } +export const PARALLEL_LIMIT = 10 +export const PARALLEL_DEPTH_LIMIT = 3 export const RETRIEVAL_OUTPUT_STRUCT = `{ "content": "", @@ -412,4 +430,5 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' +export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 3645e184491d89..af2a1500baac0b 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -16,6 +16,7 @@ import { useReactFlow, useStoreApi, } from 'reactflow' +import { unionBy } from 'lodash-es' import type { ToolDefaultValue } from '../block-selector/types' import type { Edge, @@ -25,6 +26,7 @@ import type { import { BlockEnum } from '../types' import { useWorkflowStore } from '../store' import { + CUSTOM_EDGE, ITERATION_CHILDREN_Z_INDEX, ITERATION_PADDING, NODES_INITIAL_DATA, @@ -40,6 +42,7 @@ import { } from '../utils' import { CUSTOM_NOTE_NODE } from '../note-node/constants' import type { IterationNodeType } from '../nodes/iteration/types' +import { CUSTOM_ITERATION_START_NODE } from '../nodes/iteration-start/constants' import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types' import { useNodeIterationInteractions } from '../nodes/iteration/use-interactions' import { useWorkflowHistoryStore } from '../workflow-history-store' @@ -60,6 +63,7 @@ export const useNodesInteractions = () => { const { store: workflowHistoryStore } = useWorkflowHistoryStore() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { + checkNestedParallelLimit, getAfterNodesInSameBranch, } = useWorkflow() const { getNodesReadOnly } = useNodesReadOnly() @@ -79,7 +83,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart || node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_NOTE_NODE) return dragNodeStartPosition.current = { x: node.position.x, y: node.position.y } @@ -89,7 +93,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart) + if (node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -156,7 +160,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -207,13 +211,30 @@ export const useNodesInteractions = () => { }) }) setEdges(newEdges) + const connectedEdges = getConnectedEdges([node], edges).filter(edge => edge.target === node.id) + + const targetNodes: Node[] = [] + for (let i = 0; i < connectedEdges.length; i++) { + const sourceConnectedEdges = getConnectedEdges([{ id: connectedEdges[i].source } as Node], edges).filter(edge => edge.source === connectedEdges[i].source && edge.sourceHandle === connectedEdges[i].sourceHandle) + targetNodes.push(...sourceConnectedEdges.map(edge => nodes.find(n => n.id === edge.target)!)) + } + const uniqTargetNodes = unionBy(targetNodes, 'id') + if (uniqTargetNodes.length > 1) { + const newNodes = produce(nodes, (draft) => { + draft.forEach((n) => { + if (uniqTargetNodes.some(targetNode => n.id === targetNode.id)) + n.data._inParallelHovering = true + }) + }) + setNodes(newNodes) + } }, [store, workflowStore, getNodesReadOnly]) const handleNodeLeave = useCallback((_, node) => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -229,6 +250,7 @@ export const useNodesInteractions = () => { const newNodes = produce(getNodes(), (draft) => { draft.forEach((node) => { node.data._isEntering = false + node.data._inParallelHovering = false }) }) setNodes(newNodes) @@ -287,6 +309,8 @@ export const useNodesInteractions = () => { }, [store, handleSyncWorkflowDraft]) const handleNodeClick = useCallback((_, node) => { + if (node.type === CUSTOM_ITERATION_START_NODE) + return handleNodeSelect(node.id) }, [handleNodeSelect]) @@ -314,25 +338,15 @@ export const useNodesInteractions = () => { if (targetNode?.parentId !== sourceNode?.parentId) return - if (targetNode?.data.isIterationStart) - return - if (sourceNode?.type === CUSTOM_NOTE_NODE || targetNode?.type === CUSTOM_NOTE_NODE) return - const needDeleteEdges = edges.filter((edge) => { - if ( - (edge.source === source && edge.sourceHandle === sourceHandle) - || (edge.target === target && edge.targetHandle === targetHandle && targetNode?.data.type !== BlockEnum.VariableAssigner && targetNode?.data.type !== BlockEnum.VariableAggregator) - ) - return true + if (edges.find(edge => edge.source === source && edge.sourceHandle === sourceHandle && edge.target === target && edge.targetHandle === targetHandle)) + return - return false - }) - const needDeleteEdgesIds = needDeleteEdges.map(edge => edge.id) const newEdge = { id: `${source}-${sourceHandle}-${target}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: source!, target: target!, sourceHandle, @@ -347,7 +361,6 @@ export const useNodesInteractions = () => { } const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( [ - ...needDeleteEdges.map(edge => ({ type: 'remove', edge })), { type: 'add', edge: newEdge }, ], nodes, @@ -362,19 +375,26 @@ export const useNodesInteractions = () => { } }) }) - setNodes(newNodes) const newEdges = produce(edges, (draft) => { - const filtered = draft.filter(edge => !needDeleteEdgesIds.includes(edge.id)) - - filtered.push(newEdge) - - return filtered + draft.push(newEdge) }) - setEdges(newEdges) - handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) - }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + } + else { + const { + setConnectingNodePayload, + setEnteringNodePayload, + } = workflowStore.getState() + setConnectingNodePayload(undefined) + setEnteringNodePayload(undefined) + } + }, [getNodesReadOnly, store, workflowStore, handleSyncWorkflowDraft, saveStateToHistory, checkNestedParallelLimit]) const handleNodeConnectStart = useCallback((_, { nodeId, handleType, handleId }) => { if (getNodesReadOnly()) @@ -393,14 +413,12 @@ export const useNodesInteractions = () => { return } - if (!node.data.isIterationStart) { - setConnectingNodePayload({ - nodeId, - nodeType: node.data.type, - handleType, - handleId, - }) - } + setConnectingNodePayload({ + nodeId, + nodeType: node.data.type, + handleType, + handleId, + }) } }, [store, workflowStore, getNodesReadOnly]) @@ -510,6 +528,12 @@ export const useNodesInteractions = () => { return handleNodeDelete(nodeId) } else { + if (iterationChildren.length === 1) { + handleNodeDelete(iterationChildren[0].id) + handleNodeDelete(nodeId) + + return + } const { setShowConfirm, showConfirm } = workflowStore.getState() if (!showConfirm) { @@ -541,14 +565,8 @@ export const useNodesInteractions = () => { } } - if (node.id === currentNode.parentId) { + if (node.id === currentNode.parentId) node.data._children = node.data._children?.filter(child => child !== nodeId) - - if (currentNode.id === (node as Node).data.start_node_id) { - (node as Node).data.start_node_id = ''; - (node as Node).data.startNodeType = undefined - } - } }) draft.splice(currentNodeIndex, 1) }) @@ -559,7 +577,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - if (currentNode.type === 'custom-note') + if (currentNode.type === CUSTOM_NOTE_NODE) saveStateToHistory(WorkflowHistoryEvent.NoteDelete) else @@ -591,7 +609,10 @@ export const useNodesInteractions = () => { } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -627,7 +648,7 @@ export const useNodesInteractions = () => { const newEdge: Edge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -662,8 +683,10 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) + if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { const { setShowAssignVariablePopup } = workflowStore.getState() @@ -687,7 +710,14 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } } if (!prevNodeId && nextNodeId) { const nextNodeIndex = nodes.findIndex(node => node.id === nextNodeId) @@ -706,15 +736,13 @@ export const useNodesInteractions = () => { newNode.data.iteration_id = nextNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX } - if (nextNode.data.isIterationStart) - newNode.data.isIterationStart = true let newEdge if ((nodeType !== BlockEnum.IfElse) && (nodeType !== BlockEnum.QuestionClassifier)) { newEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -763,13 +791,11 @@ export const useNodesInteractions = () => { node.data.start_node_id = newNode.id node.data.startNodeType = newNode.data.type } - - if (node.id === nextNodeId && node.data.isIterationStart) - node.data.isIterationStart = false }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) if (newEdge) { const newEdges = produce(edges, (draft) => { draft.forEach((item) => { @@ -780,7 +806,21 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } + } + else { + if (checkNestedParallelLimit(newNodes, edges)) + setNodes(newNodes) + + else + return false } } if (prevNodeId && nextNodeId) { @@ -804,7 +844,7 @@ export const useNodesInteractions = () => { const currentEdgeIndex = edges.findIndex(edge => edge.source === prevNodeId && edge.target === nextNodeId) const newPrevEdge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -822,7 +862,7 @@ export const useNodesInteractions = () => { if (nodeType !== BlockEnum.IfElse && nodeType !== BlockEnum.QuestionClassifier) { newNextEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -865,6 +905,8 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { @@ -898,7 +940,7 @@ export const useNodesInteractions = () => { } handleSyncWorkflowDraft() saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch]) + }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit]) const handleNodeChange = useCallback(( currentNodeId: string, @@ -919,7 +961,10 @@ export const useNodesInteractions = () => { const currentNode = nodes.find(node => node.id === currentNodeId)! const connectedEdges = getConnectedEdges([currentNode], edges) const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newCurrentNode = generateNewNode({ + const { + newNode: newCurrentNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -929,7 +974,6 @@ export const useNodesInteractions = () => { selected: currentNode.data.selected, isInIteration: currentNode.data.isInIteration, iteration_id: currentNode.data.iteration_id, - isIterationStart: currentNode.data.isIterationStart, }, position: { x: currentNode.position.x, @@ -955,18 +999,12 @@ export const useNodesInteractions = () => { ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], } } - if (node.id === currentNode.parentId && currentNode.data.isIterationStart) { - node.data._children = [ - newCurrentNode.id, - ...(node.data._children || []), - ].filter(child => child !== currentNodeId) - node.data.start_node_id = newCurrentNode.id - node.data.startNodeType = newCurrentNode.data.type - } }) const index = draft.findIndex(node => node.id === currentNodeId) draft.splice(index, 1, newCurrentNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) const newEdges = produce(edges, (draft) => { @@ -1011,7 +1049,7 @@ export const useNodesInteractions = () => { }, [store]) const handleNodeContextMenu = useCallback((e: MouseEvent, node: Node) => { - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return e.preventDefault() @@ -1041,7 +1079,7 @@ export const useNodesInteractions = () => { if (nodeId) { // If nodeId is provided, copy that specific node - const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start) + const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start && node.type !== CUSTOM_ITERATION_START_NODE) if (nodeToCopy) setClipboardElements([nodeToCopy]) } @@ -1087,7 +1125,10 @@ export const useNodesInteractions = () => { clipboardElements.forEach((nodeToPaste, index) => { const nodeType = nodeToPaste.data.type - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ type: nodeToPaste.type, data: { ...NODES_INITIAL_DATA[nodeType], @@ -1106,24 +1147,17 @@ export const useNodesInteractions = () => { zIndex: nodeToPaste.zIndex, }) newNode.id = newNode.id + index - - // If only the iteration start node is copied, remove the isIterationStart flag // This new node is movable and can be placed anywhere - if (clipboardElements.length === 1 && newNode.data.isIterationStart) - newNode.data.isIterationStart = false - let newChildren: Node[] = [] if (nodeToPaste.data.type === BlockEnum.Iteration) { - newNode.data._children = []; - (newNode.data as IterationNodeType).start_node_id = '' + newIterationStartNode!.parentId = newNode.id; + (newNode.data as IterationNodeType).start_node_id = newIterationStartNode!.id newChildren = handleNodeIterationChildrenCopy(nodeToPaste.id, newNode.id) - newChildren.forEach((child) => { newNode.data._children?.push(child.id) - if (child.data.isIterationStart) - (newNode.data as IterationNodeType).start_node_id = child.id }) + newChildren.push(newIterationStartNode!) } nodesToPaste.push(newNode) @@ -1230,6 +1264,42 @@ export const useNodesInteractions = () => { saveStateToHistory(WorkflowHistoryEvent.NodeResize) }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + const handleNodeDisconnect = useCallback((nodeId: string) => { + if (getNodesReadOnly()) + return + + const { + getNodes, + setNodes, + edges, + setEdges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId)! + const connectedEdges = getConnectedEdges([currentNode], edges) + const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( + connectedEdges.map(edge => ({ type: 'remove', edge })), + nodes, + ) + const newNodes = produce(nodes, (draft: Node[]) => { + draft.forEach((node) => { + if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) { + node.data = { + ...node.data, + ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], + } + } + }) + }) + setNodes(newNodes) + const newEdges = produce(edges, (draft) => { + return draft.filter(edge => !connectedEdges.find(connectedEdge => connectedEdge.id === edge.id)) + }) + setEdges(newEdges) + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.EdgeDelete) + }, [store, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory]) + const handleHistoryBack = useCallback(() => { if (getNodesReadOnly() || getWorkflowReadOnly()) return @@ -1282,6 +1352,7 @@ export const useNodesInteractions = () => { handleNodesDuplicate, handleNodesDelete, handleNodeResize, + handleNodeDisconnect, handleHistoryBack, handleHistoryForward, } diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index a872aee398d66d..e1da503f3825f1 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -1,5 +1,6 @@ import { useCallback } from 'react' import { + getIncomers, useReactFlow, useStoreApi, } from 'reactflow' @@ -8,6 +9,7 @@ import { v4 as uuidV4 } from 'uuid' import { usePathname } from 'next/navigation' import { useWorkflowStore } from '../store' import { useNodesSyncDraft } from '../hooks' +import type { Node } from '../types' import { NodeRunningStatus, WorkflowRunningStatus, @@ -140,9 +142,6 @@ export const useWorkflowRun = () => { resultText: '', }) - let isInIteration = false - let iterationLength = 0 - let ttsUrl = '' let ttsIsPublic = false if (params.token) { @@ -186,7 +185,7 @@ export const useWorkflowRun = () => { draft.forEach((edge) => { edge.data = { ...edge.data, - _run: false, + _runned: false, } }) }) @@ -249,19 +248,20 @@ export const useWorkflowRun = () => { setEdges, transform, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const node = nodes.find(node => node.id === data.node_id) + if (node?.parentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration.push({ + const iterations = tracing.find(trace => trace.node_id === node?.parentId) + const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] + currIteration?.push({ ...data, status: NodeRunningStatus.Running, } as any) })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -288,11 +288,12 @@ export const useWorkflowRun = () => { draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running }) setNodes(newNodes) + const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id) const newEdges = produce(edges, (draft) => { - const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId) - - if (edge) - edge.data = { ...edge.data, _run: true } as any + draft.forEach((edge) => { + if (edge.target === data.node_id && incomeNodesId.includes(edge.source)) + edge.data = { ...edge.data, _runned: true } as any + }) }) setEdges(newEdges) } @@ -309,25 +310,46 @@ export const useWorkflowRun = () => { getNodes, setNodes, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId + if (nodeParentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - const nodeInfo = currIteration[currIteration.length - 1] - - currIteration[currIteration.length - 1] = { - ...nodeInfo, - ...data, - status: NodeRunningStatus.Succeeded, - } as any + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] + + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) + if (data.status === NodeRunningStatus.Succeeded) { + if (nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + } })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) - + const currentIndex = draft.tracing!.findIndex((trace) => { + if (!trace.execution_metadata?.parallel_id) + return trace.node_id === data.node_id + return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id + }) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras @@ -337,16 +359,14 @@ export const useWorkflowRun = () => { } as any } })) - const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - currentNode.data._runningStatus = data.status as any }) setNodes(newNodes) - prevNodeId = data.node_id } + if (onNodeFinished) onNodeFinished(params) }, @@ -371,8 +391,6 @@ export const useWorkflowRun = () => { details: [], } as any) })) - isInIteration = true - iterationLength = data.metadata.iterator_length const { setViewport, @@ -398,7 +416,7 @@ export const useWorkflowRun = () => { const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId) if (edge) - edge.data = { ...edge.data, _run: true } as any + edge.data = { ...edge.data, _runned: true } as any }) setEdges(newEdges) @@ -418,13 +436,13 @@ export const useWorkflowRun = () => { } = store.getState() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const iteration = draft.tracing![draft.tracing!.length - 1] - if (iteration.details!.length >= iterationLength) - return - - iteration.details!.push([]) + const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id) + if (iteration) { + if (iteration.details!.length >= iteration.metadata.iterator_length!) + return + } + iteration?.details!.push([]) })) - const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! @@ -450,13 +468,14 @@ export const useWorkflowRun = () => { const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], - ...data, - status: NodeRunningStatus.Succeeded, - } as any + const currIterationNode = tracing.find(trace => trace.node_id === data.node_id) + if (currIterationNode) { + Object.assign(currIterationNode, { + ...data, + status: NodeRunningStatus.Succeeded, + }) + } })) - isInIteration = false const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! @@ -470,6 +489,12 @@ export const useWorkflowRun = () => { if (onIterationFinish) onIterationFinish(params) }, + onParallelBranchStarted: (params) => { + // console.log(params, 'parallel start') + }, + onParallelBranchFinished: (params) => { + // console.log(params, 'finished') + }, onTextChunk: (params) => { const { data: { text } } = params const { diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow/hooks/use-workflow-template.ts index 3af3f733f1e43f..e36f0b61f9b26f 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow/hooks/use-workflow-template.ts @@ -10,13 +10,13 @@ export const useWorkflowTemplate = () => { const isChatMode = useIsChatMode() const nodesInitialData = useNodesInitialData() - const startNode = generateNewNode({ + const { newNode: startNode } = generateNewNode({ data: nodesInitialData.start, position: START_INITIAL_POSITION, }) if (isChatMode) { - const llmNode = generateNewNode({ + const { newNode: llmNode } = generateNewNode({ id: 'llm', data: { ...nodesInitialData.llm, @@ -31,7 +31,7 @@ export const useWorkflowTemplate = () => { }, } as any) - const answerNode = generateNewNode({ + const { newNode: answerNode } = generateNewNode({ id: 'answer', data: { ...nodesInitialData.answer, diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index cfff4220fa81ab..460e36ae60ab59 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -6,6 +6,7 @@ import { } from 'react' import dayjs from 'dayjs' import { uniqBy } from 'lodash-es' +import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { getIncomers, @@ -29,6 +30,11 @@ import { useWorkflowStore, } from '../store' import { + getParallelInfo, +} from '../utils' +import { + PARALLEL_DEPTH_LIMIT, + PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' import { CUSTOM_NOTE_NODE } from '../note-node/constants' @@ -50,6 +56,7 @@ import { } from '@/service/tools' import I18n from '@/context/i18n' import { CollectionType } from '@/app/components/tools/types' +import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' export const useIsChatMode = () => { const appDetail = useAppStore(s => s.appDetail) @@ -58,6 +65,7 @@ export const useIsChatMode = () => { } export const useWorkflow = () => { + const { t } = useTranslation() const { locale } = useContext(I18n) const store = useStoreApi() const workflowStore = useWorkflowStore() @@ -77,7 +85,7 @@ export const useWorkflow = () => { const currentNode = nodes.find(node => node.id === nodeId) if (currentNode?.parentId) - startNode = nodes.find(node => node.parentId === currentNode.parentId && node.data.isIterationStart) + startNode = nodes.find(node => node.parentId === currentNode.parentId && node.type === CUSTOM_ITERATION_START_NODE) if (!startNode) return [] @@ -275,6 +283,45 @@ export const useWorkflow = () => { return isUsed }, [isVarUsedInNodes]) + const checkParallelLimit = useCallback((nodeId: string) => { + const { + getNodes, + edges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId)! + const sourceNodeOutgoers = getOutgoers(currentNode, nodes, edges) + if (sourceNodeOutgoers.length > PARALLEL_LIMIT - 1) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) + return false + } + + return true + }, [store, workflowStore, t]) + + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { + const { + parallelList, + hasAbnormalEdges, + } = getParallelInfo(nodes, edges, parentNodeId) + + if (hasAbnormalEdges) + return false + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > PARALLEL_DEPTH_LIMIT) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT })) + return false + } + } + + return true + }, [t, workflowStore]) + const isValidConnection = useCallback(({ source, target }: Connection) => { const { edges, @@ -284,12 +331,15 @@ export const useWorkflow = () => { const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! - if (targetNode.data.isIterationStart) + if (!checkParallelLimit(source!)) return false if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) return false + if (sourceNode.parentId !== targetNode.parentId) + return false + if (sourceNode && targetNode) { const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start] @@ -316,7 +366,7 @@ export const useWorkflow = () => { } return !hasCycle(targetNode) - }, [store, nodesExtraData]) + }, [store, nodesExtraData, checkParallelLimit]) const formatTimeFromNow = useCallback((time: number) => { return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() @@ -339,6 +389,8 @@ export const useWorkflow = () => { isVarUsedInNodes, removeUsedVarInNodes, isNodeVarsUsedInNodes, + checkParallelLimit, + checkNestedParallelLimit, isValidConnection, formatTimeFromNow, getNode, diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index d96faa8677d453..cdccd60a3b5a16 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -55,6 +55,8 @@ import Header from './header' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' +import CustomIterationStartNode from './nodes/iteration-start' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import Operator from './operator' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' @@ -67,6 +69,7 @@ import NodeContextmenu from './node-contextmenu' import SyncingDataModal from './syncing-data-modal' import UpdateDSLModal from './update-dsl-modal' import DSLExportConfirmModal from './dsl-export-confirm-modal' +import LimitTips from './limit-tips' import { useStore, useWorkflowStore, @@ -92,6 +95,7 @@ import Confirm from '@/app/components/base/confirm' const nodeTypes = { [CUSTOM_NODE]: CustomNode, [CUSTOM_NOTE_NODE]: CustomNoteNode, + [CUSTOM_ITERATION_START_NODE]: CustomIterationStartNode, } const edgeTypes = { [CUSTOM_NODE]: CustomEdge, @@ -317,6 +321,7 @@ const Workflow: FC = memo(({ /> ) } + { + const showTips = useStore(s => s.showTips) + const setShowTips = useStore(s => s.setShowTips) + + if (!showTips) + return null + + return ( +
+
+
+ +
+
+ {showTips} +
+ setShowTips('')} + > + + +
+ ) +} + +export default LimitTips diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 0ab0c8e39e2f6e..6e3988eecb64b0 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -21,13 +21,13 @@ type AddProps = { nodeId: string nodeData: CommonNodeType sourceHandle: string - branchName?: string + isParallel?: boolean } const Add = ({ nodeId, nodeData, sourceHandle, - branchName, + isParallel, }: AddProps) => { const { t } = useTranslation() const { handleNodeAdd } = useNodesInteractions() @@ -57,23 +57,19 @@ const Add = ({ ${nodesReadOnly && '!cursor-not-allowed'} `} > - { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - }
- {t('workflow.panel.selectNextStep')} +
+ { + isParallel + ? t('workflow.common.addParallelNode') + : t('workflow.panel.selectNextStep') + } +
) - }, [branchName, t, nodesReadOnly]) + }, [t, nodesReadOnly, isParallel]) return ( { + return ( +
+ { + branchName && ( +
+ {branchName} +
+ ) + } + { + nextNodes.map(nextNode => ( + + )) + } + +
+ ) +} + +export default Container diff --git a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx index 261eb3fac71d95..d980eb284e2ee9 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx @@ -1,4 +1,5 @@ -import { memo } from 'react' +import { memo, useMemo } from 'react' +import { useTranslation } from 'react-i18next' import { getConnectedEdges, getOutgoers, @@ -8,13 +9,11 @@ import { import { useToolIcon } from '../../../../hooks' import BlockIcon from '../../../../block-icon' import type { - Branch, Node, } from '../../../../types' import { BlockEnum } from '../../../../types' -import Add from './add' -import Item from './item' import Line from './line' +import Container from './container' type NextStepProps = { selectedNode: Node @@ -22,15 +21,33 @@ type NextStepProps = { const NextStep = ({ selectedNode, }: NextStepProps) => { + const { t } = useTranslation() const data = selectedNode.data const toolIcon = useToolIcon(data) const store = useStoreApi() - const branches = data._targetBranches || [] + const branches = useMemo(() => { + return data._targetBranches || [] + }, [data]) const nodeWithBranches = data.type === BlockEnum.IfElse || data.type === BlockEnum.QuestionClassifier const edges = useEdges() const outgoers = getOutgoers(selectedNode as Node, store.getState().getNodes(), edges) const connectedEdges = getConnectedEdges([selectedNode] as Node[], edges).filter(edge => edge.source === selectedNode!.id) + const branchesOutgoers = useMemo(() => { + if (!branches?.length) + return [] + + return branches.map((branch) => { + const connected = connectedEdges.filter(edge => edge.sourceHandle === branch.id) + const nextNodes = connected.map(edge => outgoers.find(outgoer => outgoer.id === edge.target)!) + + return { + branch, + nextNodes, + } + }) + }, [branches, connectedEdges, outgoers]) + return (
@@ -39,59 +56,32 @@ const NextStep = ({ toolIcon={toolIcon} />
- -
+ item.nextNodes.length + 1) : [1]} + /> +
{ - !nodeWithBranches && !!outgoers.length && ( - - ) - } - { - !nodeWithBranches && !outgoers.length && ( - ) } { - !!branches?.length && nodeWithBranches && ( - branches.map((branch: Branch) => { - const connected = connectedEdges.find(edge => edge.sourceHandle === branch.id) - const target = outgoers.find(outgoer => outgoer.id === connected?.target) - + nodeWithBranches && ( + branchesOutgoers.map((item, index) => { return ( -
- { - connected && ( - - ) - } - { - !connected && ( - - ) - } -
+ ) }) ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx index b806de5684c336..db3748abd905d4 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx @@ -1,94 +1,82 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' -import { intersection } from 'lodash-es' +import Operator from './operator' import type { CommonNodeType, - OnSelectBlock, } from '@/app/components/workflow/types' import BlockIcon from '@/app/components/workflow/block-icon' -import BlockSelector from '@/app/components/workflow/block-selector' import { - useAvailableBlocks, useNodesInteractions, useNodesReadOnly, useToolIcon, } from '@/app/components/workflow/hooks' import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' type ItemProps = { nodeId: string sourceHandle: string - branchName?: string data: CommonNodeType } const Item = ({ nodeId, sourceHandle, - branchName, data, }: ItemProps) => { const { t } = useTranslation() - const { handleNodeChange } = useNodesInteractions() + const [open, setOpen] = useState(false) const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeSelect } = useNodesInteractions() const toolIcon = useToolIcon(data) - const { - availablePrevBlocks, - availableNextBlocks, - } = useAvailableBlocks(data.type, data.isInIteration) - const handleSelect = useCallback((type, toolDefaultValue) => { - handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) - }, [nodeId, sourceHandle, handleNodeChange]) - const renderTrigger = useCallback((open: boolean) => { - return ( - - ) - }, [t]) + const handleOpenChange = useCallback((v: boolean) => { + setOpen(v) + }, []) return (
- { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - } -
{data.title}
+
+ {data.title} +
{ !nodesReadOnly && ( - item !== data.type)} - /> + <> + +
+ +
+ ) }
diff --git a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx index b06a02c15874a0..3a4430cb5de9e3 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx @@ -1,56 +1,70 @@ import { memo } from 'react' type LineProps = { - linesNumber: number + list: number[] } const Line = ({ - linesNumber, + list, }: LineProps) => { - const svgHeight = linesNumber * 36 + (linesNumber - 1) * 12 + const listHeight = list.map((item) => { + return item * 36 + (item - 1) * 2 + 12 + 6 + }) + const processedList = listHeight.map((item, index) => { + if (index === 0) + return item + + return listHeight.slice(0, index).reduce((acc, cur) => acc + cur, 0) + item + }) + const processedListLength = processedList.length + const svgHeight = processedList[processedListLength - 1] + (processedListLength - 1) * 8 return ( { - Array(linesNumber).fill(0).map((_, index) => ( - - { - index === 0 && ( - <> + processedList.map((item, index) => { + const prevItem = index > 0 ? processedList[index - 1] : 0 + const space = prevItem + index * 8 + 16 + return ( + + { + index === 0 && ( + <> + + + + ) + } + { + index > 0 && ( - - - ) - } - { - index > 0 && ( - - ) - } - - - )) + ) + } + + + ) + }) } ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx new file mode 100644 index 00000000000000..ad6c7abd0ce796 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx @@ -0,0 +1,129 @@ +import { + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiMoreFill } from '@remixicon/react' +import { intersection } from 'lodash-es' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import BlockSelector from '@/app/components/workflow/block-selector' +import { + useAvailableBlocks, + useNodesInteractions, +} from '@/app/components/workflow/hooks' +import type { + CommonNodeType, + OnSelectBlock, +} from '@/app/components/workflow/types' + +type ChangeItemProps = { + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const ChangeItem = ({ + data, + nodeId, + sourceHandle, +}: ChangeItemProps) => { + const { t } = useTranslation() + + const { handleNodeChange } = useNodesInteractions() + const { + availablePrevBlocks, + availableNextBlocks, + } = useAvailableBlocks(data.type, data.isInIteration) + + const handleSelect = useCallback((type, toolDefaultValue) => { + handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) + }, [nodeId, sourceHandle, handleNodeChange]) + + const renderTrigger = useCallback(() => { + return ( +
+ {t('workflow.panel.change')} +
+ ) + }, [t]) + + return ( + item !== data.type)} + /> + ) +} + +type OperatorProps = { + open: boolean + onOpenChange: (v: boolean) => void + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const Operator = ({ + open, + onOpenChange, + data, + nodeId, + sourceHandle, +}: OperatorProps) => { + const { t } = useTranslation() + const { + handleNodeDelete, + handleNodeDisconnect, + } = useNodesInteractions() + + return ( + + onOpenChange(!open)}> + + + +
+
+ +
handleNodeDisconnect(nodeId)} + > + {t('workflow.common.disconnect')} +
+
+
+
handleNodeDelete(nodeId)} + > + {t('common.operation.delete')} +
+
+
+
+
+ ) +} + +export default Operator diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index 56870f79d6de63..bcd03d6a5ec56e 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -9,16 +9,22 @@ import { Handle, Position, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { BlockEnum } from '../../../types' import type { Node } from '../../../types' import BlockSelector from '../../../block-selector' import type { ToolDefaultValue } from '../../../block-selector/types' import { useAvailableBlocks, + useIsChatMode, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '../../../hooks' -import { useStore } from '../../../store' +import { + useStore, +} from '../../../store' +import Tooltip from '@/app/components/base/tooltip' type NodeHandleProps = { handleId: string @@ -38,9 +44,7 @@ export const NodeTargetHandle = memo(({ const { getNodesReadOnly } = useNodesReadOnly() const connected = data._connectedTargetHandleIds?.includes(handleId) const { availablePrevBlocks } = useAvailableBlocks(data.type, data.isInIteration) - const isConnectable = !!availablePrevBlocks.length && ( - !data.isIterationStart - ) + const isConnectable = !!availablePrevBlocks.length const handleOpenChange = useCallback((v: boolean) => { setOpen(v) @@ -112,12 +116,15 @@ export const NodeSourceHandle = memo(({ handleClassName, nodeSelectorClassName, }: NodeHandleProps) => { + const { t } = useTranslation() const notInitialWorkflow = useStore(s => s.notInitialWorkflow) const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { getNodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration) const isConnectable = !!availableNextBlocks.length + const isChatMode = useIsChatMode() + const { checkParallelLimit } = useWorkflow() const connected = data._connectedSourceHandleIds?.includes(handleId) const handleOpenChange = useCallback((v: boolean) => { @@ -125,9 +132,9 @@ export const NodeSourceHandle = memo(({ }, []) const handleHandleClick = useCallback((e: MouseEvent) => { e.stopPropagation() - if (!connected) + if (checkParallelLimit(id)) setOpen(v => !v) - }, [connected]) + }, [checkParallelLimit, id]) const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { handleNodeAdd( { @@ -142,12 +149,25 @@ export const NodeSourceHandle = memo(({ }, [handleNodeAdd, id, handleId]) useEffect(() => { - if (notInitialWorkflow && data.type === BlockEnum.Start) + if (notInitialWorkflow && data.type === BlockEnum.Start && !isChatMode) setOpen(true) - }, [notInitialWorkflow, data.type]) + }, [notInitialWorkflow, data.type, isChatMode]) return ( - <> + +
+ {t('workflow.common.parallelTip.click.title')} + {t('workflow.common.parallelTip.click.desc')} +
+
+ {t('workflow.common.parallelTip.drag.title')} + {t('workflow.common.parallelTip.drag.desc')} +
+
+ )} + > { - !connected && isConnectable && !getNodesReadOnly() && ( + isConnectable && !getNodesReadOnly() && ( - + ) }) NodeSourceHandle.displayName = 'NodeSourceHandle' diff --git a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx index 4c83bea8d6e761..a8e7a9aa11294b 100644 --- a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx @@ -28,8 +28,8 @@ const NodeResizer = ({ nodeId, nodeData, icon = , - minWidth = 272, - minHeight = 176, + minWidth = 258, + minHeight = 152, maxWidth, }: NodeResizerProps) => { const { handleNodeResize } = useNodesInteractions() diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 0b45c808881c48..bd5921c7356442 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -14,6 +14,7 @@ import { RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' +import { useTranslation } from 'react-i18next' import type { NodeProps } from '../../types' import { BlockEnum, @@ -43,6 +44,7 @@ const BaseNode: FC = ({ data, children, }) => { + const { t } = useTranslation() const nodeRef = useRef(null) const { nodesReadOnly } = useNodesReadOnly() const { handleNodeIterationChildSizeChange } = useNodeIterationInteractions() @@ -80,6 +82,7 @@ const BaseNode: FC = ({ className={cn( 'flex border-[2px] rounded-2xl', showSelectedBorder ? 'border-components-option-card-option-selected-border' : 'border-transparent', + !showSelectedBorder && data._inParallelHovering && 'border-workflow-block-border-highlight', )} ref={nodeRef} style={{ @@ -100,6 +103,13 @@ const BaseNode: FC = ({ data._isBundled && '!shadow-lg', )} > + { + data._inParallelHovering && ( +
+ {t('workflow.common.parallelRun')} +
+ ) + } { data._showAddVariablePopup && ( { ComparisonOperator.isNot, ComparisonOperator.empty, ComparisonOperator.notEmpty, - ComparisonOperator.regexMatch, ] case VarType.number: return [ diff --git a/web/app/components/workflow/nodes/iteration-start/constants.ts b/web/app/components/workflow/nodes/iteration-start/constants.ts new file mode 100644 index 00000000000000..94e3ccbd90414c --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/constants.ts @@ -0,0 +1 @@ +export const CUSTOM_ITERATION_START_NODE = 'custom-iteration-start' diff --git a/web/app/components/workflow/nodes/iteration-start/default.ts b/web/app/components/workflow/nodes/iteration-start/default.ts new file mode 100644 index 00000000000000..d98efa7ba2f2a3 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/default.ts @@ -0,0 +1,21 @@ +import type { NodeDefault } from '../../types' +import type { IterationStartNodeType } from './types' +import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' + +const nodeDefault: NodeDefault = { + defaultValue: {}, + getAvailablePrevNodes() { + return [] + }, + getAvailableNextNodes(isChatMode: boolean) { + const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + return nodes + }, + checkValid() { + return { + isValid: true, + } + }, +} + +export default nodeDefault diff --git a/web/app/components/workflow/nodes/iteration-start/index.tsx b/web/app/components/workflow/nodes/iteration-start/index.tsx new file mode 100644 index 00000000000000..9d7ac1f9050085 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/index.tsx @@ -0,0 +1,42 @@ +import { memo } from 'react' +import { useTranslation } from 'react-i18next' +import type { NodeProps } from 'reactflow' +import { RiHome5Fill } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' +import { NodeSourceHandle } from '@/app/components/workflow/nodes/_base/components/node-handle' + +const IterationStartNode = ({ id, data }: NodeProps) => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+ +
+ ) +} + +export const IterationStartNodeDumb = () => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+
+ ) +} + +export default memo(IterationStartNode) diff --git a/web/app/components/workflow/nodes/iteration-start/types.ts b/web/app/components/workflow/nodes/iteration-start/types.ts new file mode 100644 index 00000000000000..319cce0bc24023 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/types.ts @@ -0,0 +1,3 @@ +import type { CommonNodeType } from '@/app/components/workflow/types' + +export type IterationStartNodeType = CommonNodeType diff --git a/web/app/components/workflow/nodes/iteration/add-block.tsx b/web/app/components/workflow/nodes/iteration/add-block.tsx index fd8480b7dff706..07e2b5daf0e168 100644 --- a/web/app/components/workflow/nodes/iteration/add-block.tsx +++ b/web/app/components/workflow/nodes/iteration/add-block.tsx @@ -2,87 +2,49 @@ import { memo, useCallback, } from 'react' -import produce from 'immer' import { RiAddLine, } from '@remixicon/react' -import { useStoreApi } from 'reactflow' import { useTranslation } from 'react-i18next' import { - generateNewNode, -} from '../../utils' -import { - WorkflowHistoryEvent, useAvailableBlocks, + useNodesInteractions, useNodesReadOnly, - useWorkflowHistory, } from '../../hooks' -import { NODES_INITIAL_DATA } from '../../constants' -import InsertBlock from './insert-block' import type { IterationNodeType } from './types' import cn from '@/utils/classnames' import BlockSelector from '@/app/components/workflow/block-selector' -import { IterationStart } from '@/app/components/base/icons/src/vender/workflow' import type { OnSelectBlock, } from '@/app/components/workflow/types' import { BlockEnum, } from '@/app/components/workflow/types' -import Tooltip from '@/app/components/base/tooltip' type AddBlockProps = { iterationNodeId: string iterationNodeData: IterationNodeType } const AddBlock = ({ - iterationNodeId, iterationNodeData, }: AddBlockProps) => { const { t } = useTranslation() - const store = useStoreApi() const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeAdd } = useNodesInteractions() const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, true) - const { availablePrevBlocks } = useAvailableBlocks(iterationNodeData.startNodeType, true) - const { saveStateToHistory } = useWorkflowHistory() const handleSelect = useCallback((type, toolDefaultValue) => { - const { - getNodes, - setNodes, - } = store.getState() - const nodes = getNodes() - const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ - data: { - ...NODES_INITIAL_DATA[type], - title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), - ...(toolDefaultValue || {}), - isIterationStart: true, - isInIteration: true, - iteration_id: iterationNodeId, + handleNodeAdd( + { + nodeType: type, + toolDefaultValue, }, - position: { - x: 117, - y: 85, + { + prevNodeId: iterationNodeData.start_node_id, + prevNodeSourceHandle: 'source', }, - zIndex: 1001, - parentId: iterationNodeId, - extent: 'parent', - }) - const newNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - if (node.id === iterationNodeId) { - node.data._children = [newNode.id] - node.data.start_node_id = newNode.id - node.data.startNodeType = newNode.data.type - } - }) - draft.push(newNode) - }) - setNodes(newNodes) - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [store, t, iterationNodeId, saveStateToHistory]) + ) + }, [handleNodeAdd, iterationNodeData.start_node_id]) const renderTriggerElement = useCallback((open: boolean) => { return ( @@ -98,35 +60,18 @@ const AddBlock = ({ }, [nodesReadOnly, t]) return ( -
- -
- -
-
+
- { - iterationNodeData.startNodeType && ( - - ) - }
- { - !iterationNodeData.startNodeType && ( - - ) - } +
) } diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 43f8a751acbbea..3afa52d06ec62a 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -9,6 +9,7 @@ const nodeDefault: NodeDefault = { start_node_id: '', iterator_selector: [], output_selector: [], + _children: [], }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/iteration/insert-block.tsx b/web/app/components/workflow/nodes/iteration/insert-block.tsx deleted file mode 100644 index d041fe1c746812..00000000000000 --- a/web/app/components/workflow/nodes/iteration/insert-block.tsx +++ /dev/null @@ -1,61 +0,0 @@ -import { - memo, - useCallback, - useState, -} from 'react' -import { useNodesInteractions } from '../../hooks' -import type { - BlockEnum, - OnSelectBlock, -} from '../../types' -import BlockSelector from '../../block-selector' -import cn from '@/utils/classnames' - -type InsertBlockProps = { - startNodeId: string - availableBlocksTypes: BlockEnum[] -} -const InsertBlock = ({ - startNodeId, - availableBlocksTypes, -}: InsertBlockProps) => { - const [open, setOpen] = useState(false) - const { handleNodeAdd } = useNodesInteractions() - - const handleOpenChange = useCallback((v: boolean) => { - setOpen(v) - }, []) - const handleInsert = useCallback((nodeType, toolDefaultValue) => { - handleNodeAdd( - { - nodeType, - toolDefaultValue, - }, - { - nextNodeId: startNodeId, - nextNodeTargetHandle: 'target', - }, - ) - }, [startNodeId, handleNodeAdd]) - - return ( -
- 'hover:scale-125 transition-all'} - /> -
- ) -} - -export default memo(InsertBlock) diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index f4520402f35312..48a005a261df61 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,6 +8,7 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' @@ -29,7 +30,7 @@ const Node: FC> = ({ return (
> = ({ size={2 / zoom} color='#E4E5E7' /> - + { + data._isCandidate && ( + + ) + } + { + data._children!.length === 1 && ( + + ) + }
) } diff --git a/web/app/components/workflow/nodes/iteration/use-interactions.ts b/web/app/components/workflow/nodes/iteration/use-interactions.ts index 219c8e731f4b61..f8e3640cc4808c 100644 --- a/web/app/components/workflow/nodes/iteration/use-interactions.ts +++ b/web/app/components/workflow/nodes/iteration/use-interactions.ts @@ -11,6 +11,7 @@ import { ITERATION_PADDING, NODES_INITIAL_DATA, } from '../../constants' +import { CUSTOM_ITERATION_START_NODE } from '../iteration-start/constants' export const useNodeIterationInteractions = () => { const { t } = useTranslation() @@ -107,12 +108,12 @@ export const useNodeIterationInteractions = () => { const handleNodeIterationChildrenCopy = useCallback((nodeId: string, newNodeId: string) => { const { getNodes } = store.getState() const nodes = getNodes() - const childrenNodes = nodes.filter(n => n.parentId === nodeId) + const childrenNodes = nodes.filter(n => n.parentId === nodeId && n.type !== CUSTOM_ITERATION_START_NODE) return childrenNodes.map((child, index) => { const childNodeType = child.data.type as BlockEnum const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[childNodeType], ...child.data, @@ -121,6 +122,7 @@ export const useNodeIterationInteractions = () => { _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${childNodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${childNodeType}`), + iteration_id: newNodeId, }, position: child.position, positionAbsolute: child.positionAbsolute, diff --git a/web/app/components/workflow/operator/add-block.tsx b/web/app/components/workflow/operator/add-block.tsx index 48222cc5280624..388fbc053f0713 100644 --- a/web/app/components/workflow/operator/add-block.tsx +++ b/web/app/components/workflow/operator/add-block.tsx @@ -55,7 +55,7 @@ const AddBlock = ({ } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[type], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), diff --git a/web/app/components/workflow/operator/hooks.ts b/web/app/components/workflow/operator/hooks.ts index 5b142114977f38..edec10bda781b8 100644 --- a/web/app/components/workflow/operator/hooks.ts +++ b/web/app/components/workflow/operator/hooks.ts @@ -11,7 +11,7 @@ export const useOperator = () => { const { userProfile } = useAppContext() const handleAddNote = useCallback(() => { - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ type: CUSTOM_NOTE_NODE, data: { title: '', diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 54d3915a1362b0..51a018bcb15b43 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -180,8 +180,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) const bodyParams = { @@ -317,11 +315,11 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, - onIterationNext: () => { + onIterationNext: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] + const iterations = tracing.find(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! iterations.details!.push([]) handleUpdateChatList(produce(chatListRef.current, (draft) => { @@ -331,9 +329,10 @@ export const useChat = ( }, onIterationFinish: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - tracing[tracing.length - 1] = { - ...iterations, + const iterationsIndex = tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + tracing[iterationsIndex] = { + ...tracing[iterationsIndex], ...data, status: NodeRunningStatus.Succeeded, } as any @@ -341,67 +340,45 @@ export const useChat = ( const currentIndex = draft.length - 1 draft[currentIndex] = responseItem })) - - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) { - const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.length - 1 - draft[currentIndex] = responseItem - })) - } - else { - responseItem.workflowProcess!.tracing!.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.findIndex(item => item.id === responseItem.id) - draft[currentIndex] = { - ...draft[currentIndex], - ...responseItem, - } - })) - } + if (data.iteration_id) + return + + responseItem.workflowProcess!.tracing!.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) }, onNodeFinished: ({ data }) => { - if (isInIteration) { - const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration[currIteration.length - 1] = { - ...data, - status: NodeRunningStatus.Succeeded, - } as any - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.length - 1 - draft[currentIndex] = responseItem - })) - } - else { - const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id) - responseItem.workflowProcess!.tracing[currentIndex] = { - ...(responseItem.workflowProcess!.tracing[currentIndex].extras - ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } - : {}), - ...data, - } as any - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.findIndex(item => item.id === responseItem.id) - draft[currentIndex] = { - ...draft[currentIndex], - ...responseItem, - } - })) - } + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + responseItem.workflowProcess!.tracing[currentIndex] = { + ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras + ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } + : {}), + ...data, + } as any + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) }, }, ) diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index b9b77fc4a3ba1a..331ef1c2f50dc6 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -63,26 +63,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const formatNodeList = useCallback((list: NodeTracing[]) => { const allItems = list.reverse() const result: NodeTracing[] = [] - let iterationIndex = 0 allItems.forEach((item) => { const { node_type, execution_metadata } = item if (node_type !== BlockEnum.Iteration) { const isInIteration = !!execution_metadata?.iteration_id if (isInIteration) { - const iterationDetails = result[result.length - 1].details! - const currentIterationIndex = execution_metadata?.iteration_index - const isIterationFirstNode = iterationIndex !== currentIterationIndex || iterationDetails.length === 0 - - if (isIterationFirstNode) { - iterationDetails!.push([item]) - iterationIndex = currentIterationIndex! - } - - else { - iterationDetails[iterationDetails.length - 1].push(item) + const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) + const iterationDetails = iterationNode?.details + const currentIterationIndex = execution_metadata?.iteration_index ?? 0 + + if (Array.isArray(iterationDetails)) { + if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) + iterationDetails[currentIterationIndex] = [item] + else + iterationDetails[currentIterationIndex].push(item) } - return } // not in iteration @@ -90,7 +86,6 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe return } - result.push({ ...item, details: [], diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index c833ea0342e6a1..4fc30f03df4ce6 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -1,10 +1,14 @@ 'use client' import type { FC } from 'react' -import React, { useCallback } from 'react' +import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { RiCloseLine } from '@remixicon/react' +import { + RiArrowRightSLine, + RiCloseLine, +} from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' -import NodePanel from './node' +import TracingPanel from './tracing-panel' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import type { NodeTracing } from '@/types/workflow' const i18nPrefix = 'workflow.singleRun' @@ -23,43 +27,67 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() + const [expandedIterations, setExpandedIterations] = useState>([]) + + const toggleIteration = useCallback((index: number) => { + setExpandedIterations(prev => ({ + ...prev, + [index]: !prev[index], + })) + }, []) const main = ( <> -
+
-
+
{t(`${i18nPrefix}.testRunIteration`)}
- +
-
+
-
{t(`${i18nPrefix}.back`)}
+
{t(`${i18nPrefix}.back`)}
{/* List */} -
+
{list.map((iteration, index) => ( -
-
-
{t(`${i18nPrefix}.iteration`)} {index + 1}
-
+
+
toggleIteration(index)} + > +
+
+ +
+ + {t(`${i18nPrefix}.iteration`)} {index + 1} + + +
-
- {iteration.map(node => ( - - ))} + {expandedIterations[index] &&
} +
+
))} diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index 66f996f13b4862..2e45290ddf2888 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -4,15 +4,17 @@ import type { FC } from 'react' import { useCallback, useEffect, useState } from 'react' import { RiArrowRightSLine, - RiCheckboxCircleLine, + RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' import Split from '../nodes/_base/components/split' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import type { NodeTracing } from '@/types/workflow' @@ -61,31 +63,38 @@ const NodePanel: FC = ({ return `${parseFloat((tokens / 1000000).toFixed(3))}M` } + const getCount = (iteration_curr_length: number | undefined, iteration_length: number) => { + if ((iteration_curr_length && iteration_curr_length < iteration_length) || !iteration_length) + return iteration_curr_length + + return iteration_length + } + useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration - const handleOnShowIterationDetail = (e: React.MouseEvent) => { + const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || []) } return ( -
-
+
+
setCollapseState(!collapseState)} > {!hideProcessDetail && ( @@ -93,23 +102,23 @@ const NodePanel: FC = ({
{nodeInfo.title}
{nodeInfo.status !== 'running' && !hideInfo && ( -
{`${getTime(nodeInfo.elapsed_time || 0)} · ${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens`}
+
{nodeInfo.execution_metadata?.total_tokens ? `${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens · ` : ''}{`${getTime(nodeInfo.elapsed_time || 0)}`}
)} {nodeInfo.status === 'succeeded' && ( - + )} {nodeInfo.status === 'failed' && ( - + )} {nodeInfo.status === 'stopped' && ( )} {nodeInfo.status === 'running' && ( -
+
Running
@@ -120,25 +129,27 @@ const NodePanel: FC = ({ {/* The nav to the iteration detail */} {isIterationNode && !notShowIterationNav && (
-
-
{t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length || nodeInfo.details?.length })}
+
+
)} -
+
{nodeInfo.status === 'stopped' && (
{t('workflow.tracing.stopBy', { user: nodeInfo.created_by ? nodeInfo.created_by.name : 'N/A' })}
)} diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index c80bcb41db2785..7684dc84b65830 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -1,24 +1,266 @@ 'use client' import type { FC } from 'react' +import +React, +{ + useCallback, + useState, +} from 'react' +import cn from 'classnames' +import { + RiArrowDownSLine, + RiMenu4Line, +} from '@remixicon/react' import NodePanel from './node' +import { + BlockEnum, +} from '@/app/components/workflow/types' import type { NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] - onShowIterationDetail: (detail: NodeTracing[][]) => void + onShowIterationDetail?: (detail: NodeTracing[][]) => void + className?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean } -const TracingPanel: FC = ({ list, onShowIterationDetail }) => { +type TracingNodeProps = { + id: string + uniqueId: string + isParallel: boolean + data: NodeTracing | null + children: TracingNodeProps[] + parallelTitle?: string + branchTitle?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean +} + +function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { + const rootNodes: TracingNodeProps[] = [] + const parallelStacks: { [key: string]: TracingNodeProps } = {} + const levelCounts: { [key: string]: number } = {} + const parallelChildCounts: { [key: string]: Set } = {} + let uniqueIdCounter = 0 + const getUniqueId = () => { + uniqueIdCounter++ + return `unique-${uniqueIdCounter}` + } + + const getParallelTitle = (parentId: string | null): string => { + const levelKey = parentId || 'root' + if (!levelCounts[levelKey]) + levelCounts[levelKey] = 0 + + levelCounts[levelKey]++ + + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + return `PARALLEL-${levelNumber}${letter}` + } + + const getBranchTitle = (parentId: string | null, branchNum: number): string => { + const levelKey = parentId || 'root' + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + const branchLetter = String.fromCharCode(64 + branchNum) + return `BRANCH-${levelNumber}${letter}-${branchLetter}` + } + + // Count parallel children (for figuring out if we need to use letters) + for (const node of nodes) { + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + + if (parallel_id) { + const parentKey = parent_parallel_id || 'root' + if (!parallelChildCounts[parentKey]) + parallelChildCounts[parentKey] = new Set() + + parallelChildCounts[parentKey].add(parallel_id) + } + } + + for (const node of nodes) { + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_start_node_id = node.parallel_start_node_id ?? node.execution_metadata?.parallel_start_node_id ?? null + const parent_parallel_start_node_id = node.parent_parallel_start_node_id ?? node.execution_metadata?.parent_parallel_start_node_id ?? null + + if (!parallel_id || node.node_type === BlockEnum.End) { + rootNodes.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + }) + } + else { + if (!parallelStacks[parallel_id]) { + const newParallelGroup: TracingNodeProps = { + id: parallel_id, + uniqueId: getUniqueId(), + isParallel: true, + data: null, + children: [], + parallelTitle: '', + } + parallelStacks[parallel_id] = newParallelGroup + + if (parent_parallel_id && parallelStacks[parent_parallel_id]) { + const sameBranchIndex = parallelStacks[parent_parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parent_parallel_start_node_id || c.data?.parallel_start_node_id === parent_parallel_start_node_id, + ) + parallelStacks[parent_parallel_id].children.splice(sameBranchIndex + 1, 0, newParallelGroup) + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + } + else { + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + rootNodes.push(newParallelGroup) + } + } + const branchTitle = parallel_start_node_id === node.node_id ? getBranchTitle(parent_parallel_id, parallelStacks[parallel_id].children.length + 1) : '' + if (branchTitle) { + parallelStacks[parallel_id].children.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + else { + let sameBranchIndex = parallelStacks[parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parallel_start_node_id || c.data?.parallel_start_node_id === parallel_start_node_id, + ) + if (parallelStacks[parallel_id].children[sameBranchIndex + 1]?.isParallel) + sameBranchIndex++ + + parallelStacks[parallel_id].children.splice(sameBranchIndex + 1, 0, { + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + } + } + + return rootNodes +} + +const TracingPanel: FC = ({ + list, + onShowIterationDetail, + className, + hideNodeInfo = false, + hideNodeProcessDetail = false, +}) => { + const treeNodes = buildLogTree(list) + const [collapsedNodes, setCollapsedNodes] = useState>(new Set()) + const [hoveredParallel, setHoveredParallel] = useState(null) + + const toggleCollapse = (id: string) => { + setCollapsedNodes((prev) => { + const newSet = new Set(prev) + if (newSet.has(id)) + newSet.delete(id) + + else + newSet.add(id) + + return newSet + }) + } + + const handleParallelMouseEnter = useCallback((id: string) => { + setHoveredParallel(id) + }, []) + + const handleParallelMouseLeave = useCallback((e: React.MouseEvent) => { + const relatedTarget = e.relatedTarget as Element | null + if (relatedTarget && 'closest' in relatedTarget) { + const closestParallel = relatedTarget.closest('[data-parallel-id]') + if (closestParallel) + setHoveredParallel(closestParallel.getAttribute('data-parallel-id')) + + else + setHoveredParallel(null) + } + else { + setHoveredParallel(null) + } + }, []) + + const renderNode = (node: TracingNodeProps) => { + if (node.isParallel) { + const isCollapsed = collapsedNodes.has(node.id) + const isHovered = hoveredParallel === node.id + return ( +
handleParallelMouseEnter(node.id)} + onMouseLeave={handleParallelMouseLeave} + > +
+ +
+ {node.parallelTitle} +
+
+
+
+
+ {node.children.map(renderNode)} +
+
+ ) + } + else { + const isHovered = hoveredParallel === node.id + return ( +
+
+ {node.branchTitle} +
+ +
+ ) + } + } + return ( -
- {list.map(node => ( - - ))} +
+ {treeNodes.map(renderNode)}
) } diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index 2e5e7741913f84..853d0c5934759d 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -162,6 +162,8 @@ type Shape = { setControlPromptEditorRerenderKey: (controlPromptEditorRerenderKey: number) => void showImportDSLModal: boolean setShowImportDSLModal: (showImportDSLModal: boolean) => void + showTips: string + setShowTips: (showTips: string) => void } export const createWorkflowStore = () => { @@ -262,6 +264,8 @@ export const createWorkflowStore = () => { setControlPromptEditorRerenderKey: controlPromptEditorRerenderKey => set(() => ({ controlPromptEditorRerenderKey })), showImportDSLModal: false, setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), + showTips: '', + setShowTips: showTips => set(() => ({ showTips })), })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 12957aa0dd01c7..797c2dbd855847 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -26,6 +26,7 @@ export enum BlockEnum { Tool = 'tool', ParameterExtractor = 'parameter-extractor', Iteration = 'iteration', + IterationStart = 'iteration-start', Assigner = 'assigner', // is now named as VariableAssigner } @@ -54,7 +55,7 @@ export type CommonNodeType = { _holdAddVariablePopup?: boolean _iterationLength?: number _iterationIndex?: number - isIterationStart?: boolean + _inParallelHovering?: boolean isInIteration?: boolean iteration_id?: string selected?: boolean diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index e73485f5462caf..91656e3bbcb17c 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -1,12 +1,15 @@ import { Position, getConnectedEdges, + getIncomers, getOutgoers, } from 'reactflow' import dagre from '@dagrejs/dagre' import { v4 as uuid4 } from 'uuid' import { cloneDeep, + groupBy, + isEqual, uniqBy, } from 'lodash-es' import type { @@ -19,14 +22,17 @@ import type { import { BlockEnum } from './types' import { CUSTOM_NODE, + ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION, } from './constants' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import type { QuestionClassifierNodeType } from './nodes/question-classifier/types' import type { IfElseNodeType } from './nodes/if-else/types' import { branchNameCorrect } from './nodes/if-else/utils' import type { ToolNodeType } from './nodes/tool/types' +import type { IterationNodeType } from './nodes/iteration/types' import { CollectionType } from '@/app/components/tools/types' import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' @@ -84,9 +90,130 @@ const getCycleEdges = (nodes: Node[], edges: Edge[]) => { return cycleEdges } +export function getIterationStartNode(iterationId: string): Node { + return generateNewNode({ + id: `${iterationId}start`, + type: CUSTOM_ITERATION_START_NODE, + data: { + title: '', + desc: '', + type: BlockEnum.IterationStart, + isInIteration: true, + }, + position: { + x: 24, + y: 68, + }, + zIndex: ITERATION_CHILDREN_Z_INDEX, + parentId: iterationId, + selectable: false, + draggable: false, + }).newNode +} + +export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit & { id?: string }): { + newNode: Node + newIterationStartNode?: Node +} { + const newNode = { + id: id || `${Date.now()}`, + type: type || CUSTOM_NODE, + data, + position, + targetPosition: Position.Left, + sourcePosition: Position.Right, + zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex, + ...rest, + } as Node + + if (data.type === BlockEnum.Iteration) { + const newIterationStartNode = getIterationStartNode(newNode.id); + (newNode.data as IterationNodeType).start_node_id = newIterationStartNode.id; + (newNode.data as IterationNodeType)._children = [newIterationStartNode.id] + return { + newNode, + newIterationStartNode, + } + } + + return { + newNode, + } +} + +export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => { + const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration) + + if (!hasIterationNode) { + return { + nodes, + edges, + } + } + const nodesMap = nodes.reduce((prev, next) => { + prev[next.id] = next + return prev + }, {} as Record) + const iterationNodesWithStartNode = [] + const iterationNodesWithoutStartNode = [] + + for (let i = 0; i < nodes.length; i++) { + const currentNode = nodes[i] as Node + + if (currentNode.data.type === BlockEnum.Iteration) { + if (currentNode.data.start_node_id) { + if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE) + iterationNodesWithStartNode.push(currentNode) + } + else { + iterationNodesWithoutStartNode.push(currentNode) + } + } + } + const newIterationStartNodesMap = {} as Record + const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => { + const newNode = getIterationStartNode(iterationNode.id) + newNode.id = newNode.id + index + newIterationStartNodesMap[iterationNode.id] = newNode + return newNode + }) + const newEdges = iterationNodesWithStartNode.map((iterationNode) => { + const newNode = newIterationStartNodesMap[iterationNode.id] + const startNode = nodesMap[iterationNode.data.start_node_id] + const source = newNode.id + const sourceHandle = 'source' + const target = startNode.id + const targetHandle = 'target' + return { + id: `${source}-${sourceHandle}-${target}-${targetHandle}`, + type: 'custom', + source, + sourceHandle, + target, + targetHandle, + data: { + sourceType: newNode.data.type, + targetType: startNode.data.type, + isInIteration: true, + iteration_id: startNode.parentId, + _connectedNodeIsSelected: true, + }, + zIndex: ITERATION_CHILDREN_Z_INDEX, + } + }) + nodes.forEach((node) => { + if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id]) + (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id + }) + + return { + nodes: [...nodes, ...newIterationStartNodes], + edges: [...edges, ...newEdges], + } +} + export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { - const nodes = cloneDeep(originNodes) - const edges = cloneDeep(originEdges) + const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges)) const firstNode = nodes[0] if (!firstNode?.position) { @@ -148,8 +275,7 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { } export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => { - const nodes = cloneDeep(originNodes) - const edges = cloneDeep(originEdges) + const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges)) let selectedNode: Node | null = null const nodesMap = nodes.reduce((acc, node) => { acc[node.id] = node @@ -291,19 +417,6 @@ export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSo return nodesConnectedSourceOrTargetHandleIdsMap } -export const generateNewNode = ({ data, position, id, zIndex, type, ...rest }: Omit & { id?: string }) => { - return { - id: id || `${Date.now()}`, - type: type || CUSTOM_NODE, - data, - position, - targetPosition: Position.Left, - sourcePosition: Position.Right, - zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex, - ...rest, - } as Node -} - export const genNewNodeTitleFromOld = (oldTitle: string) => { const regex = /^(.+?)\s*\((\d+)\)\s*$/ const match = oldTitle.match(regex) @@ -479,3 +592,167 @@ export const variableTransformer = (v: ValueSelector | string) => { return `{{#${v.join('.')}#}}` } + +type ParallelInfoItem = { + parallelNodeId: string + depth: number + isBranch?: boolean +} +type NodeParallelInfo = { + parallelNodeId: string + edgeHandleId: string + depth: number +} +type NodeHandle = { + node: Node + handle: string +} +type NodeStreamInfo = { + upstreamNodes: Set + downstreamEdges: Set +} +export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => { + let startNode + + if (parentNodeId) { + const parentNode = nodes.find(node => node.id === parentNodeId) + if (!parentNode) + throw new Error('Parent node not found') + + startNode = nodes.find(node => node.id === (parentNode.data as IterationNodeType).start_node_id) + } + else { + startNode = nodes.find(node => node.data.type === BlockEnum.Start) + } + if (!startNode) + throw new Error('Start node not found') + + const parallelList = [] as ParallelInfoItem[] + const nextNodeHandles = [{ node: startNode, handle: 'source' }] + let hasAbnormalEdges = false + + const traverse = (firstNodeHandle: NodeHandle) => { + const nodeEdgesSet = {} as Record> + const totalEdgesSet = new Set() + const nextHandles = [firstNodeHandle] + const streamInfo = {} as Record + const parallelListItem = { + parallelNodeId: '', + depth: 0, + } as ParallelInfoItem + const nodeParallelInfoMap = {} as Record + nodeParallelInfoMap[firstNodeHandle.node.id] = { + parallelNodeId: '', + edgeHandleId: '', + depth: 0, + } + + while (nextHandles.length) { + const currentNodeHandle = nextHandles.shift()! + const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle + const currentNodeHandleKey = currentNode.id + const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle) + const connectedEdgesLength = connectedEdges.length + const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id)) + const incomers = getIncomers(currentNode, nodes, edges) + + if (!streamInfo[currentNodeHandleKey]) { + streamInfo[currentNodeHandleKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) { + const newSet = new Set() + for (const item of totalEdgesSet) { + if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item)) + newSet.add(item) + } + if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) { + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + nextNodeHandles.push({ node: currentNode, handle: currentHandle }) + break + } + } + + if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth) + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + + outgoers.forEach((outgoer) => { + const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id) + const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle') + const incomers = getIncomers(outgoer, nodes, edges) + + if (outgoers.length > 1 && incomers.length > 1) + hasAbnormalEdges = true + + Object.keys(sourceEdgesGroup).forEach((sourceHandle) => { + nextHandles.push({ node: outgoer, handle: sourceHandle }) + }) + if (!outgoerConnectedEdges.length) + nextHandles.push({ node: outgoer, handle: 'source' }) + + const outgoerKey = outgoer.id + if (!nodeEdgesSet[outgoerKey]) + nodeEdgesSet[outgoerKey] = new Set() + + if (nodeEdgesSet[currentNodeHandleKey]) { + for (const item of nodeEdgesSet[currentNodeHandleKey]) + nodeEdgesSet[outgoerKey].add(item) + } + + if (!streamInfo[outgoerKey]) { + streamInfo[outgoerKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (!nodeParallelInfoMap[outgoer.id]) { + nodeParallelInfoMap[outgoer.id] = { + ...nodeParallelInfoMap[currentNode.id], + } + } + + if (connectedEdgesLength > 1) { + const edge = connectedEdges.find(edge => edge.target === outgoer.id)! + nodeEdgesSet[outgoerKey].add(edge.id) + totalEdgesSet.add(edge.id) + + streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id) + streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey) + + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[item].downstreamEdges.add(edge.id) + + if (!parallelListItem.parallelNodeId) + parallelListItem.parallelNodeId = currentNode.id + + const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1 + const currentDepth = nodeParallelInfoMap[outgoer.id].depth + + nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth) + } + else { + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[outgoerKey].upstreamNodes.add(item) + + nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth + } + }) + } + + parallelList.push(parallelListItem) + } + + while (nextNodeHandles.length) { + const nodeHandle = nextNodeHandles.shift()! + traverse(nodeHandle) + } + + return { + parallelList, + hasAbnormalEdges, + } +} diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 0d924ba16ce005..b83d213cb851d7 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Overwrite and Import', importFailure: 'Import failure', importSuccess: 'Import success', + parallelRun: 'Parallel Run', + parallelTip: { + click: { + title: 'Click', + desc: ' to add', + }, + drag: { + title: 'Drag', + desc: ' to connect', + }, + limit: 'Parallelism is limited to {{num}} branches.', + depthLimit: 'Parallel nesting layer limit of {{num}} layers', + }, + disconnect: 'Disconnect', + jumpToNode: 'Jump to this node', + addParallelNode: 'Add Parallel Node', }, env: { envPanelTitle: 'Environment Variables', @@ -412,7 +428,6 @@ const translation = { 'not empty': 'is not empty', 'null': 'is null', 'not null': 'is not null', - 'regex match': 'regex match', }, enterValue: 'Enter value', addCondition: 'Add Condition', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 0f00b117c12dde..39311649f44f69 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: '覆盖并导入', importFailure: '导入失败', importSuccess: '导入成功', + parallelRun: '并行运行', + parallelTip: { + click: { + title: '点击', + desc: '添加节点', + }, + drag: { + title: '拖拽', + desc: '连接节点', + }, + limit: '并行分支限制为 {{num}} 个', + depthLimit: '并行嵌套层数限制 {{num}} 层', + }, + disconnect: '断开连接', + jumpToNode: '跳转到节点', + addParallelNode: '添加并行节点', }, env: { envPanelTitle: '环境变量', @@ -412,7 +428,6 @@ const translation = { 'not empty': '不为空', 'null': '空', 'not null': '不为空', - 'regex match': '正则匹配', }, enterValue: '输入值', addCondition: '添加条件', diff --git a/web/package.json b/web/package.json index b6275f20aeb738..374286f8f717c6 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.7.3", + "version": "0.8.0", "private": true, "engines": { "node": ">=18.17.0" diff --git a/web/service/base.ts b/web/service/base.ts index 8f1b22bc1b364e..83389d8be80fa3 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -8,6 +8,8 @@ import type { IterationStartedResponse, NodeFinishedResponse, NodeStartedResponse, + ParallelBranchFinishedResponse, + ParallelBranchStartedResponse, TextChunkResponse, TextReplaceResponse, WorkflowFinishedResponse, @@ -59,6 +61,8 @@ export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void +export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void +export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void export type IOnTextChunk = (textChunk: TextChunkResponse) => void export type IOnTTSChunk = (messageId: string, audioStr: string, audioType?: string) => void export type IOnTTSEnd = (messageId: string, audioStr: string, audioType?: string) => void @@ -86,6 +90,8 @@ export type IOtherOptions = { onIterationStart?: IOnIterationStarted onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onParallelBranchStarted?: IOnParallelBranchStarted + onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk onTTSChunk?: IOnTTSChunk onTTSEnd?: IOnTTSEnd @@ -139,6 +145,8 @@ const handleStream = ( onIterationStart?: IOnIterationStarted, onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onParallelBranchStarted?: IOnParallelBranchStarted, + onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, onTTSChunk?: IOnTTSChunk, onTTSEnd?: IOnTTSEnd, @@ -228,6 +236,12 @@ const handleStream = ( else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'parallel_branch_started') { + onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) + } + else if (bufferObj.event === 'parallel_branch_finished') { + onParallelBranchFinished?.(bufferObj as ParallelBranchFinishedResponse) + } else if (bufferObj.event === 'text_chunk') { onTextChunk?.(bufferObj as TextChunkResponse) } @@ -488,6 +502,8 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onParallelBranchStarted, + onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, @@ -544,7 +560,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' && !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/themes/dark.css b/web/themes/dark.css index 8aab0f5fbb5221..8d77329b5a4948 100644 --- a/web/themes/dark.css +++ b/web/themes/dark.css @@ -316,6 +316,7 @@ html[data-theme="dark"] { --color-workflow-block-border: #FFFFFF14; --color-workflow-block-parma-bg: #FFFFFF0D; --color-workflow-block-bg: #27272B; + --color-workflow-block-border-highlight: #C8CEDA33; --color-workflow-canvas-workflow-dot-color: #8585AD26; --color-workflow-canvas-workflow-bg: #1D1D20; diff --git a/web/themes/light.css b/web/themes/light.css index 5a5ef637692d72..89303c250ea3cb 100644 --- a/web/themes/light.css +++ b/web/themes/light.css @@ -316,6 +316,7 @@ html[data-theme="light"] { --color-workflow-block-border: #FFFFFF; --color-workflow-block-parma-bg: #F2F4F7; --color-workflow-block-bg: #FCFCFD; + --color-workflow-block-border-highlight: #155AEF24; --color-workflow-canvas-workflow-dot-color: #8585AD26; --color-workflow-canvas-workflow-bg: #F2F4F7; diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts index 9178d23f248fb3..643c96d1a11d9b 100644 --- a/web/themes/tailwind-theme-var-define.ts +++ b/web/themes/tailwind-theme-var-define.ts @@ -316,6 +316,7 @@ const vars = { 'workflow-block-border': 'var(--color-workflow-block-border)', 'workflow-block-parma-bg': 'var(--color-workflow-block-parma-bg)', 'workflow-block-bg': 'var(--color-workflow-block-bg)', + 'workflow-block-border-highlight': 'var(--color-workflow-block-border-highlight)', 'workflow-canvas-workflow-dot-color': 'var(--color-workflow-canvas-workflow-dot-color)', 'workflow-canvas-workflow-bg': 'var(--color-workflow-canvas-workflow-bg)', diff --git a/web/types/workflow.ts b/web/types/workflow.ts index e6b001bd798fd7..dbf2b3e587b992 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -26,9 +26,14 @@ export type NodeTracing = { currency: string iteration_id?: string iteration_index?: number + parallel_id?: string + parallel_start_node_id?: string + parent_parallel_id?: string + parent_parallel_start_node_id?: string } metadata: { iterator_length: number + iterator_index: number } created_at: number created_by: { @@ -40,6 +45,10 @@ export type NodeTracing = { extras?: any expand?: boolean // for UI details?: NodeTracing[][] // iteration detail + parallel_id?: string + parallel_start_node_id?: string + parent_parallel_id?: string + parent_parallel_start_node_id?: string } export type FetchWorkflowDraftResponse = { @@ -109,6 +118,7 @@ export type NodeStartedResponse = { data: { id: string node_id: string + iteration_id?: string node_type: string index: number predecessor_node_id?: string @@ -125,6 +135,7 @@ export type NodeFinishedResponse = { data: { id: string node_id: string + iteration_id?: string node_type: string index: number predecessor_node_id?: string @@ -138,6 +149,10 @@ export type NodeFinishedResponse = { total_tokens: number total_price: number currency: string + parallel_id?: string + parallel_start_node_id?: string + iteration_index?: number + iteration_id?: string } created_at: number } @@ -152,6 +167,8 @@ export type IterationStartedResponse = { node_id: string metadata: { iterator_length: number + iteration_id: string + iteration_index: number } created_at: number extras?: any @@ -169,6 +186,9 @@ export type IterationNextResponse = { output: any extras?: any created_at: number + execution_metadata: { + parallel_id?: string + } } } @@ -184,6 +204,39 @@ export type IterationFinishedResponse = { status: string created_at: number error: string + execution_metadata: { + parallel_id?: string + } + } +} + +export type ParallelBranchStartedResponse = { + task_id: string + workflow_run_id: string + event: string + data: { + parallel_id: string + parallel_start_node_id: string + parent_parallel_id: string + parent_parallel_start_node_id: string + iteration_id?: string + created_at: number + } +} + +export type ParallelBranchFinishedResponse = { + task_id: string + workflow_run_id: string + event: string + data: { + parallel_id: string + parallel_start_node_id: string + parent_parallel_id: string + parent_parallel_start_node_id: string + iteration_id?: string + status: string + created_at: number + error: string } } From 178730266dedfdfd36db0762d687f80cd6b0e5c1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:13:26 +0800 Subject: [PATCH 06/13] chore: translate i18n files (#8202) Co-authored-by: takatost <5485478+takatost@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- web/i18n/de-DE/workflow.ts | 16 ++++++++++++++++ web/i18n/es-ES/workflow.ts | 16 ++++++++++++++++ web/i18n/fa-IR/workflow.ts | 16 ++++++++++++++++ web/i18n/fr-FR/workflow.ts | 16 ++++++++++++++++ web/i18n/hi-IN/workflow.ts | 16 ++++++++++++++++ web/i18n/it-IT/workflow.ts | 16 ++++++++++++++++ web/i18n/ja-JP/workflow.ts | 16 ++++++++++++++++ web/i18n/ko-KR/workflow.ts | 16 ++++++++++++++++ web/i18n/pl-PL/workflow.ts | 16 ++++++++++++++++ web/i18n/pt-BR/workflow.ts | 16 ++++++++++++++++ web/i18n/ro-RO/workflow.ts | 16 ++++++++++++++++ web/i18n/ru-RU/workflow.ts | 16 ++++++++++++++++ web/i18n/tr-TR/workflow.ts | 16 ++++++++++++++++ web/i18n/uk-UA/workflow.ts | 16 ++++++++++++++++ web/i18n/vi-VN/workflow.ts | 16 ++++++++++++++++ web/i18n/zh-Hant/workflow.ts | 16 ++++++++++++++++ 16 files changed, 256 insertions(+) diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index 2762c01a8dd410..5e154aeca52a0c 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -77,6 +77,22 @@ const translation = { importDSLTip: 'Der aktuelle Entwurf wird überschrieben. Exportieren Sie den Workflow vor dem Import als Backup.', overwriteAndImport: 'Überschreiben und Importieren', backupCurrentDraft: 'Aktuellen Entwurf sichern', + parallelTip: { + click: { + title: 'Klicken', + desc: 'hinzuzufügen', + }, + drag: { + title: 'Ziehen', + desc: 'um eine Verbindung herzustellen', + }, + limit: 'Die Parallelität ist auf {{num}} Zweige beschränkt.', + depthLimit: 'Begrenzung der parallelen Verschachtelungsschicht von {{num}} Schichten', + }, + parallelRun: 'Paralleler Lauf', + disconnect: 'Trennen', + jumpToNode: 'Zu diesem Knoten springen', + addParallelNode: 'Parallelen Knoten hinzufügen', }, env: { envPanelTitle: 'Umgebungsvariablen', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 7db2b48687eddb..0efc996f91ab95 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Sobrescribir e importar', importFailure: 'Error al importar', importSuccess: 'Importación exitosa', + parallelTip: { + click: { + title: 'Clic', + desc: 'Para agregar', + }, + drag: { + title: 'Arrastrar', + desc: 'Para conectarse', + }, + limit: 'El paralelismo se limita a {{num}} ramas.', + depthLimit: 'Límite de capa de anidamiento paralelo de capas {{num}}', + }, + parallelRun: 'Ejecución paralela', + disconnect: 'Desconectar', + jumpToNode: 'Saltar a este nodo', + addParallelNode: 'Agregar nodo paralelo', }, env: { envPanelTitle: 'Variables de Entorno', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index da2d439a36acee..67020f50250d70 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'بازنویسی و وارد کردن', importFailure: 'خطا در وارد کردن', importSuccess: 'وارد کردن موفقیت‌آمیز', + parallelTip: { + click: { + title: 'کلیک کنید', + desc: 'اضافه کردن', + }, + drag: { + desc: 'برای اتصال', + title: 'کشیدن', + }, + depthLimit: 'حد لایه تودرتو موازی لایه های {{num}}', + limit: 'موازی سازی به شاخه های {{num}} محدود می شود.', + }, + disconnect: 'قطع', + jumpToNode: 'پرش به این گره', + parallelRun: 'اجرای موازی', + addParallelNode: 'افزودن گره موازی', }, env: { envPanelTitle: 'متغیرهای محیطی', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index 8691f6f0d51cd3..3e56246e0c2013 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Écraser et importer', importFailure: 'Echec de l\'importation', importSuccess: 'Import avec succès', + parallelTip: { + click: { + title: 'Cliquer', + desc: 'à ajouter', + }, + drag: { + title: 'Traîner', + desc: 'pour se connecter', + }, + limit: 'Le parallélisme est limité aux branches {{num}}.', + depthLimit: 'Limite de couches d’imbrication parallèle de {{num}} couches', + }, + parallelRun: 'Exécution parallèle', + disconnect: 'Déconnecter', + jumpToNode: 'Aller à ce nœud', + addParallelNode: 'Ajouter un nœud parallèle', }, env: { envPanelTitle: 'Variables d\'Environnement', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 96c232d72ffd65..072e4874e35df3 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -80,6 +80,22 @@ const translation = { backupCurrentDraft: 'बैकअप वर्तमान ड्राफ्ट', importFailure: 'आयात विफलता', importDSLTip: 'वर्तमान ड्राफ्ट ओवरराइट हो जाएगा। आयात करने से पहले वर्कफ़्लो को बैकअप के रूप में निर्यात करें.', + parallelTip: { + click: { + title: 'क्लिक करना', + desc: 'जोड़ने के लिए', + }, + drag: { + title: 'खींचना', + desc: 'कनेक्ट करने के लिए', + }, + limit: 'समांतरता {{num}} शाखाओं तक सीमित है।', + depthLimit: '{{num}} परतों की समानांतर नेस्टिंग परत सीमा', + }, + disconnect: 'अलग करना', + parallelRun: 'समानांतर रन', + jumpToNode: 'इस नोड पर जाएं', + addParallelNode: 'समानांतर नोड जोड़ें', }, env: { envPanelTitle: 'पर्यावरण चर', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 62ce0c567776c5..f5d6fc8bf57c4c 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -81,6 +81,22 @@ const translation = { overwriteAndImport: 'Sovrascrivi e Importa', importFailure: 'Importazione fallita', importSuccess: 'Importazione riuscita', + parallelTip: { + click: { + title: 'Clic', + desc: 'per aggiungere', + }, + drag: { + title: 'Trascinare', + desc: 'per collegare', + }, + depthLimit: 'Limite di livelli di annidamento parallelo di {{num}} livelli', + limit: 'Il parallelismo è limitato ai rami {{num}}.', + }, + parallelRun: 'Corsa parallela', + disconnect: 'Disconnettere', + jumpToNode: 'Vai a questo nodo', + addParallelNode: 'Aggiungi nodo parallelo', }, env: { envPanelTitle: 'Variabili d\'Ambiente', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index d244e36bcd44c9..755061e8f65772 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'オーバライトとインポート', importFailure: 'インポート失敗', importSuccess: 'インポート成功', + parallelTip: { + click: { + title: 'クリック', + desc: '追加する', + }, + drag: { + title: 'ドラッグ', + desc: '接続するには', + }, + limit: '並列処理は {{num}} ブランチに制限されています。', + depthLimit: '{{num}}レイヤーの平行ネストレイヤーの制限', + }, + parallelRun: 'パラレルラン', + disconnect: '切る', + jumpToNode: 'このノードにジャンプします', + addParallelNode: '並列ノードを追加', }, env: { envPanelTitle: '環境変数', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index 7c85d02c3d65dc..8fed0e0417fc2f 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -77,6 +77,22 @@ const translation = { importSuccess: '가져오기 성공', syncingData: '단 몇 초 만에 데이터를 동기화할 수 있습니다.', importDSLTip: '현재 초안을 덮어씁니다. 가져오기 전에 워크플로를 백업으로 내보냅니다.', + parallelTip: { + click: { + title: '클릭', + desc: '추가', + }, + drag: { + title: '드래그', + desc: '연결 방법', + }, + depthLimit: '평행 중첩 레이어 {{num}}개 레이어의 제한', + limit: '병렬 처리는 {{num}}개의 분기로 제한됩니다.', + }, + parallelRun: '병렬 실행', + disconnect: '분리하다', + jumpToNode: '이 노드로 이동', + addParallelNode: '병렬 노드 추가', }, env: { envPanelTitle: '환경 변수', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 09b9cea1be68f2..de05ee7169d3d7 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -77,6 +77,22 @@ const translation = { chooseDSL: 'Wybierz plik DSL(yml)', backupCurrentDraft: 'Utwórz kopię zapasową bieżącej wersji roboczej', importFailure: 'Niepowodzenie importu', + parallelTip: { + click: { + title: 'Klikać', + desc: ', aby dodać', + }, + drag: { + title: 'Przeciągnąć', + desc: 'aby się połączyć', + }, + limit: 'Równoległość jest ograniczona do gałęzi {{num}}.', + depthLimit: 'Limit warstw zagnieżdżania równoległego dla warstw {{num}}', + }, + parallelRun: 'Bieg równoległy', + jumpToNode: 'Przejdź do tego węzła', + disconnect: 'Odłączyć', + addParallelNode: 'Dodaj węzeł równoległy', }, env: { envPanelTitle: 'Zmienne Środowiskowe', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index da412683020c20..ef589c0bde2a4e 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -77,6 +77,22 @@ const translation = { importDSLTip: 'O rascunho atual será substituído. Exporte o fluxo de trabalho como backup antes de importar.', backupCurrentDraft: 'Fazer backup do rascunho atual', importDSL: 'Importar DSL', + parallelTip: { + click: { + title: 'Clique', + desc: 'para adicionar', + }, + drag: { + title: 'Arrastar', + desc: 'para conectar', + }, + limit: 'O paralelismo é limitado a {{num}} ramificações.', + depthLimit: 'Limite de camada de aninhamento paralelo de {{num}} camadas', + }, + parallelRun: 'Execução paralela', + disconnect: 'Desligar', + jumpToNode: 'Ir para este nó', + addParallelNode: 'Adicionar nó paralelo', }, env: { envPanelTitle: 'Variáveis de Ambiente', diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index f8177aeb06a35f..689ebdead9d237 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -77,6 +77,22 @@ const translation = { importSuccess: 'Succesul importului', backupCurrentDraft: 'Backup curent draft', importDSLTip: 'Proiectul curent va fi suprascris. Exportați fluxul de lucru ca backup înainte de import.', + parallelTip: { + click: { + title: 'Clic', + desc: 'pentru a adăuga', + }, + drag: { + title: 'Glisa', + desc: 'pentru a vă conecta', + }, + depthLimit: 'Limita straturilor de imbricare paralelă a {{num}} straturi', + limit: 'Paralelismul este limitat la {{num}} ramuri.', + }, + parallelRun: 'Rulare paralelă', + disconnect: 'Deconecta', + jumpToNode: 'Sari la acest nod', + addParallelNode: 'Adăugare nod paralel', }, env: { envPanelTitle: 'Variabile de Mediu', diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index b31b798bd3a033..9d3ce1235cc512 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Перезаписать и импортировать', importFailure: 'Ошибка импорта', importSuccess: 'Импорт успешно завершен', + parallelTip: { + click: { + title: 'Щелчок', + desc: 'добавить', + }, + drag: { + title: 'Волочить', + desc: 'для подключения', + }, + limit: 'Параллелизм ограничен ветвями {{num}}.', + depthLimit: 'Ограничение на количество слоев параллельной вложенности {{num}}', + }, + parallelRun: 'Параллельный прогон', + disconnect: 'Разъединять', + jumpToNode: 'Перейти к этому узлу', + addParallelNode: 'Добавить параллельный узел', }, env: { envPanelTitle: 'Переменные среды', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 856c5d518f7f3a..96313d6d6b6fdd 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Üzerine Yaz ve İçe Aktar', importFailure: 'İçe Aktarma Başarısız', importSuccess: 'İçe Aktarma Başarılı', + parallelTip: { + click: { + desc: 'Eklemek için', + title: 'Tık', + }, + drag: { + title: 'Sürükleme', + desc: 'Bağlanmak için', + }, + depthLimit: '{{num}} katmanlarının paralel iç içe geçme katmanı sınırı', + limit: 'Paralellik {{num}} dallarıyla sınırlıdır.', + }, + jumpToNode: 'Bu düğüme atla', + addParallelNode: 'Paralel Düğüm Ekle', + disconnect: 'Ayırmak', + parallelRun: 'Paralel Koşu', }, env: { envPanelTitle: 'Çevre Değişkenleri', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index c6b98d1b5b37db..03471348c8a047 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -77,6 +77,22 @@ const translation = { chooseDSL: 'Виберіть файл DSL(yml)', backupCurrentDraft: 'Резервна поточна чернетка', importDSLTip: 'Поточна чернетка буде перезаписана. Експортуйте робочий процес як резервну копію перед імпортом.', + parallelTip: { + click: { + title: 'Натисніть', + desc: 'щоб додати', + }, + drag: { + title: 'Перетягувати', + desc: 'Щоб підключити', + }, + limit: 'Паралелізм обмежується {{num}} гілками.', + depthLimit: 'Обмеження рівня паралельного вкладеності шарів {{num}}', + }, + disconnect: 'Відключити', + parallelRun: 'Паралельний біг', + jumpToNode: 'Перейти до цього вузла', + addParallelNode: 'Додати паралельний вузол', }, env: { envPanelTitle: 'Змінні середовища', diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index d850d1a7327c9d..5be19ab7fd18ee 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Ghi đè và nhập', importDSL: 'Nhập DSL', syncingData: 'Đồng bộ hóa dữ liệu, chỉ vài giây.', + parallelTip: { + click: { + title: 'Bấm', + desc: 'để thêm', + }, + drag: { + title: 'Kéo', + desc: 'Để kết nối', + }, + limit: 'Song song được giới hạn trong các nhánh {{num}}.', + depthLimit: 'Giới hạn lớp lồng song song của {{num}} layer', + }, + parallelRun: 'Chạy song song', + disconnect: 'Ngắt kết nối', + jumpToNode: 'Chuyển đến nút này', + addParallelNode: 'Thêm nút song song', }, env: { envPanelTitle: 'Biến Môi Trường', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 618fa8ce2eedc5..eef3ffaebd9be3 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -77,6 +77,22 @@ const translation = { syncingData: '同步數據,只需幾秒鐘。', importDSLTip: '當前草稿將被覆蓋。在導入之前將工作流匯出為備份。', importFailure: '匯入失敗', + parallelTip: { + click: { + title: '點擊', + desc: '添加', + }, + drag: { + title: '拖动', + desc: '連接', + }, + limit: '並行度僅限於 {{num}} 個分支。', + depthLimit: '並行嵌套層限制為 {{num}} 個層', + }, + parallelRun: '並行運行', + disconnect: '斷開', + jumpToNode: '跳轉到此節點', + addParallelNode: '添加並行節點', }, env: { envPanelTitle: '環境變數', From 2cf1187b32c0885e9f9b987d57c6ca9da4508a2c Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Tue, 10 Sep 2024 17:00:20 +0800 Subject: [PATCH 07/13] chore(api/core): apply ruff reformatting (#7624) --- api/core/__init__.py | 2 +- api/core/agent/cot_agent_runner.py | 215 ++- api/core/agent/cot_chat_agent_runner.py | 26 +- api/core/agent/cot_completion_agent_runner.py | 21 +- api/core/agent/entities.py | 16 +- api/core/agent/fc_agent_runner.py | 235 ++- .../agent/output_parser/cot_output_parser.py | 74 +- api/core/agent/prompt/template.py | 18 +- .../app/app_config/base_app_config_manager.py | 26 +- .../sensitive_word_avoidance/manager.py | 23 +- .../easy_ui_based_app/agent/manager.py | 63 +- .../easy_ui_based_app/dataset/manager.py | 88 +- .../model_config/converter.py | 30 +- .../easy_ui_based_app/model_config/manager.py | 33 +- .../prompt_template/manager.py | 57 +- .../easy_ui_based_app/variables/manager.py | 52 +- api/core/app/app_config/entities.py | 46 +- .../features/file_upload/manager.py | 28 +- .../features/more_like_this/manager.py | 8 +- .../features/opening_statement/manager.py | 6 +- .../features/retrieval_resource/manager.py | 8 +- .../features/speech_to_text/manager.py | 8 +- .../manager.py | 14 +- .../features/text_to_speech/manager.py | 16 +- .../apps/advanced_chat/app_config_manager.py | 26 +- .../app/apps/advanced_chat/app_generator.py | 166 +- .../app_generator_tts_publisher.py | 37 +- api/core/app/apps/advanced_chat/app_runner.py | 84 +- .../generate_response_converter.py | 54 +- .../advanced_chat/generate_task_pipeline.py | 167 +- .../app/apps/agent_chat/app_config_manager.py | 59 +- api/core/app/apps/agent_chat/app_generator.py | 102 +- api/core/app/apps/agent_chat/app_runner.py | 93 +- .../agent_chat/generate_response_converter.py | 56 +- .../base_app_generate_response_converter.py | 85 +- api/core/app/apps/base_app_generator.py | 12 +- api/core/app/apps/base_app_queue_manager.py | 30 +- api/core/app/apps/base_app_runner.py | 272 ++- api/core/app/apps/chat/app_config_manager.py | 45 +- api/core/app/apps/chat/app_generator.py | 106 +- api/core/app/apps/chat/app_runner.py | 51 +- .../apps/chat/generate_response_converter.py | 56 +- .../app/apps/completion/app_config_manager.py | 35 +- api/core/app/apps/completion/app_generator.py | 175 +- api/core/app/apps/completion/app_runner.py | 36 +- .../completion/generate_response_converter.py | 50 +- .../app/apps/message_based_app_generator.py | 104 +- .../apps/message_based_app_queue_manager.py | 21 +- .../app/apps/workflow/app_config_manager.py | 18 +- api/core/app/apps/workflow/app_generator.py | 129 +- .../app/apps/workflow/app_queue_manager.py | 24 +- api/core/app/apps/workflow/app_runner.py | 23 +- .../workflow/generate_response_converter.py | 22 +- .../apps/workflow/generate_task_pipeline.py | 133 +- api/core/app/apps/workflow_app_runner.py | 150 +- .../app/apps/workflow_logging_callback.py | 200 +-- api/core/app/entities/app_invoke_entities.py | 34 +- api/core/app/entities/queue_entities.py | 54 +- api/core/app/entities/task_entities.py | 40 +- .../annotation_reply/annotation_reply.py | 60 +- .../hosting_moderation/hosting_moderation.py | 10 +- .../app/features/rate_limiting/rate_limit.py | 23 +- api/core/app/segments/__init__.py | 42 +- api/core/app/segments/factory.py | 20 +- api/core/app/segments/parser.py | 4 +- api/core/app/segments/segment_group.py | 6 +- api/core/app/segments/segments.py | 24 +- api/core/app/segments/types.py | 20 +- api/core/app/segments/variables.py | 5 +- .../based_generate_task_pipeline.py | 42 +- .../easy_ui_based_generate_task_pipeline.py | 175 +- .../app/task_pipeline/message_cycle_manage.py | 75 +- .../task_pipeline/workflow_cycle_manage.py | 117 +- .../agent_tool_callback_handler.py | 38 +- .../index_tool_callback_handler.py | 62 +- .../workflow_tool_callback_handler.py | 2 +- api/core/embedding/cached_embedding.py | 50 +- api/core/entities/agent_entities.py | 8 +- api/core/entities/message_entities.py | 6 +- api/core/entities/model_entities.py | 8 +- api/core/entities/provider_configuration.py | 368 +++-- api/core/entities/provider_entities.py | 20 +- api/core/errors/error.py | 7 + .../api_based_extension_requestor.py | 25 +- api/core/extension/extensible.py | 48 +- api/core/extension/extension.py | 5 +- api/core/external_data_tool/api/api.py | 76 +- .../external_data_tool/external_data_fetch.py | 41 +- api/core/external_data_tool/factory.py | 6 +- api/core/file/file_obj.py | 59 +- api/core/file/message_file_parser.py | 100 +- api/core/file/tool_file_parser.py | 9 +- api/core/file/upload_file_parser.py | 12 +- .../helper/code_executor/code_executor.py | 72 +- .../code_executor/code_node_provider.py | 20 +- .../javascript/javascript_code_provider.py | 3 +- .../javascript/javascript_transformer.py | 3 +- .../code_executor/jinja2/jinja2_formatter.py | 6 +- .../jinja2/jinja2_transformer.py | 4 +- .../python3/python3_code_provider.py | 3 +- .../code_executor/template_transformer.py | 14 +- api/core/helper/encrypter.py | 7 +- api/core/helper/model_provider_cache.py | 2 +- api/core/helper/moderation.py | 21 +- api/core/helper/module_import_helper.py | 9 +- api/core/helper/position_helper.py | 22 +- api/core/helper/ssrf_proxy.py | 33 +- api/core/helper/tool_parameter_cache.py | 15 +- api/core/helper/tool_provider_cache.py | 5 +- api/core/hosting_configuration.py | 106 +- api/core/indexing_runner.py | 474 +++--- api/core/llm_generator/llm_generator.py | 99 +- .../output_parser/rule_config_generator.py | 20 +- .../suggested_questions_after_answer.py | 3 +- api/core/llm_generator/prompts.py | 20 +- api/core/memory/token_buffer_memory.py | 64 +- .../model_runtime/callbacks/base_callback.py | 72 +- .../callbacks/logging_callback.py | 118 +- .../model_runtime/entities/common_entities.py | 1 + api/core/model_runtime/entities/defaults.py | 200 +-- .../model_runtime/entities/llm_entities.py | 35 +- .../entities/message_entities.py | 31 +- .../model_runtime/entities/model_entities.py | 47 +- .../entities/provider_entities.py | 17 +- .../model_runtime/entities/rerank_entities.py | 2 + .../entities/text_embedding_entities.py | 3 +- api/core/model_runtime/errors/invoke.py | 6 + api/core/model_runtime/errors/validate.py | 1 + .../model_providers/__base/ai_model.py | 79 +- .../__base/large_language_model.py | 260 +-- .../model_providers/__base/model_provider.py | 24 +- .../__base/moderation_model.py | 10 +- .../model_providers/__base/rerank_model.py | 29 +- .../__base/speech2text_model.py | 11 +- .../model_providers/__base/text2img_model.py | 13 +- .../__base/text_embedding_model.py | 13 +- .../__base/tokenizers/gpt2_tokenzier.py | 11 +- .../model_providers/__base/tts_model.py | 38 +- .../model_providers/anthropic/anthropic.py | 7 +- .../model_providers/anthropic/llm/llm.py | 306 ++-- .../model_providers/azure_openai/_common.py | 28 +- .../model_providers/azure_openai/_constant.py | 905 +++++----- .../azure_openai/azure_openai.py | 1 - .../model_providers/azure_openai/llm/llm.py | 261 ++- .../azure_openai/speech2text/speech2text.py | 9 +- .../text_embedding/text_embedding.py | 86 +- .../model_providers/azure_openai/tts/tts.py | 44 +- .../model_providers/baichuan/baichuan.py | 8 +- .../baichuan/llm/baichuan_tokenizer.py | 8 +- .../baichuan/llm/baichuan_turbo.py | 5 +- .../baichuan/llm/baichuan_turbo_errors.py | 7 +- .../model_providers/baichuan/llm/llm.py | 93 +- .../baichuan/text_embedding/text_embedding.py | 111 +- .../model_providers/bedrock/bedrock.py | 10 +- .../model_providers/bedrock/llm/llm.py | 528 +++--- .../bedrock/text_embedding/text_embedding.py | 113 +- .../model_providers/chatglm/chatglm.py | 7 +- .../model_providers/chatglm/llm/llm.py | 219 +-- .../model_providers/cohere/cohere.py | 7 +- .../model_providers/cohere/llm/llm.py | 312 ++-- .../model_providers/cohere/rerank/rerank.py | 46 +- .../cohere/text_embedding/text_embedding.py | 86 +- .../model_providers/deepseek/deepseek.py | 9 +- .../model_providers/deepseek/llm/llm.py | 45 +- .../model_providers/fishaudio/__init__.py | 1 - .../model_providers/fishaudio/fishaudio.py | 8 +- .../model_providers/fishaudio/tts/tts.py | 34 +- .../model_providers/google/google.py | 7 +- .../model_providers/google/llm/llm.py | 206 ++- .../model_providers/groq/groq.py | 9 +- .../model_providers/groq/llm/llm.py | 21 +- .../huggingface_hub/_common.py | 8 +- .../huggingface_hub/huggingface_hub.py | 1 - .../huggingface_hub/llm/llm.py | 191 +-- .../text_embedding/text_embedding.py | 109 +- .../huggingface_tei/huggingface_tei.py | 1 - .../huggingface_tei/rerank/rerank.py | 32 +- .../huggingface_tei/tei_helper.py | 54 +- .../text_embedding/text_embedding.py | 44 +- .../model_providers/hunyuan/hunyuan.py | 9 +- .../model_providers/hunyuan/llm/llm.py | 180 +- .../hunyuan/text_embedding/text_embedding.py | 51 +- .../model_providers/jina/jina.py | 8 +- .../model_providers/jina/rerank/rerank.py | 59 +- .../jina/text_embedding/jina_tokenizer.py | 8 +- .../jina/text_embedding/text_embedding.py | 92 +- .../model_providers/leptonai/leptonai.py | 9 +- .../model_providers/leptonai/llm/llm.py | 36 +- .../model_providers/localai/llm/llm.py | 355 ++-- .../model_providers/localai/localai.py | 3 +- .../model_providers/localai/rerank/rerank.py | 66 +- .../localai/speech2text/speech2text.py | 38 +- .../localai/text_embedding/text_embedding.py | 94 +- .../minimax/llm/chat_completion.py | 154 +- .../minimax/llm/chat_completion_pro.py | 159 +- .../model_providers/minimax/llm/errors.py | 7 +- .../model_providers/minimax/llm/llm.py | 202 +-- .../model_providers/minimax/llm/types.py | 31 +- .../model_providers/minimax/minimax.py | 10 +- .../minimax/text_embedding/text_embedding.py | 83 +- .../model_providers/mistralai/llm/llm.py | 23 +- .../model_providers/mistralai/mistralai.py | 8 +- .../model_providers/moonshot/llm/llm.py | 169 +- .../model_providers/moonshot/moonshot.py | 8 +- .../model_providers/novita/llm/llm.py | 52 +- .../model_providers/novita/novita.py | 7 +- .../model_providers/nvidia/llm/llm.py | 212 ++- .../model_providers/nvidia/nvidia.py | 8 +- .../model_providers/nvidia/rerank/rerank.py | 23 +- .../nvidia/text_embedding/text_embedding.py | 81 +- .../model_providers/nvidia_nim/llm/llm.py | 1 + .../model_providers/nvidia_nim/nvidia_nim.py | 1 - .../model_providers/oci/llm/llm.py | 189 ++- .../model_runtime/model_providers/oci/oci.py | 10 +- .../oci/text_embedding/text_embedding.py | 90 +- .../model_providers/ollama/llm/llm.py | 100 +- .../model_providers/ollama/ollama.py | 1 - .../ollama/text_embedding/text_embedding.py | 79 +- .../model_providers/openai/_common.py | 6 +- .../model_providers/openai/llm/llm.py | 442 ++--- .../openai/moderation/moderation.py | 10 +- .../model_providers/openai/openai.py | 8 +- .../openai/speech2text/speech2text.py | 6 +- .../openai/text_embedding/text_embedding.py | 65 +- .../model_providers/openai/tts/tts.py | 47 +- .../openai_api_compatible/_common.py | 9 +- .../openai_api_compatible/llm/llm.py | 445 +++-- .../openai_api_compatible.py | 1 - .../speech2text/speech2text.py | 4 +- .../text_embedding/text_embedding.py | 116 +- .../model_providers/openllm/llm/llm.py | 194 +-- .../openllm/llm/openllm_generate.py | 121 +- .../openllm/llm/openllm_generate_errors.py | 7 +- .../openllm/text_embedding/text_embedding.py | 67 +- .../model_providers/openrouter/llm/llm.py | 46 +- .../model_providers/openrouter/openrouter.py | 10 +- .../model_providers/perfxcloud/llm/llm.py | 42 +- .../model_providers/perfxcloud/perfxcloud.py | 8 +- .../text_embedding/text_embedding.py | 124 +- .../model_providers/replicate/_common.py | 8 +- .../model_providers/replicate/llm/llm.py | 174 +- .../model_providers/replicate/replicate.py | 1 - .../text_embedding/text_embedding.py | 97 +- .../model_providers/sagemaker/llm/llm.py | 262 ++- .../sagemaker/rerank/rerank.py | 102 +- .../model_providers/sagemaker/sagemaker.py | 30 +- .../sagemaker/speech2text/speech2text.py | 79 +- .../text_embedding/text_embedding.py | 101 +- .../model_providers/sagemaker/tts/tts.py | 232 ++- .../model_providers/siliconflow/llm/llm.py | 20 +- .../siliconflow/rerank/rerank.py | 47 +- .../siliconflow/siliconflow.py | 8 +- .../siliconflow/speech2text/speech2text.py | 4 +- .../text_embedding/text_embedding.py | 13 +- .../model_providers/spark/llm/_client.py | 130 +- .../model_providers/spark/llm/llm.py | 117 +- .../model_providers/stepfun/llm/llm.py | 166 +- .../model_providers/stepfun/stepfun.py | 8 +- .../tencent/speech2text/flash_recognizer.py | 51 +- .../tencent/speech2text/speech2text.py | 14 +- .../model_providers/tencent/tencent.py | 7 +- .../model_providers/togetherai/llm/llm.py | 90 +- .../model_providers/togetherai/togetherai.py | 1 - .../model_providers/tongyi/_common.py | 4 +- .../model_providers/tongyi/llm/llm.py | 221 +-- .../tongyi/text_embedding/text_embedding.py | 23 +- .../model_providers/tongyi/tongyi.py | 7 +- .../model_providers/tongyi/tts/tts.py | 56 +- .../triton_inference_server/llm/llm.py | 273 +-- .../triton_inference_server.py | 1 + .../model_providers/upstage/_common.py | 9 +- .../model_providers/upstage/llm/llm.py | 262 +-- .../upstage/text_embedding/text_embedding.py | 63 +- .../model_providers/upstage/upstage.py | 11 +- .../model_providers/vertex_ai/llm/llm.py | 302 ++-- .../text_embedding/text_embedding.py | 52 +- .../model_providers/vertex_ai/vertex_ai.py | 7 +- .../model_providers/volcengine_maas/client.py | 129 +- .../volcengine_maas/legacy/client.py | 69 +- .../volcengine_maas/legacy/errors.py | 50 +- .../legacy/volc_sdk/__init__.py | 2 +- .../legacy/volc_sdk/base/auth.py | 109 +- .../legacy/volc_sdk/base/service.py | 51 +- .../legacy/volc_sdk/base/util.py | 18 +- .../volcengine_maas/legacy/volc_sdk/common.py | 20 +- .../volcengine_maas/legacy/volc_sdk/maas.py | 59 +- .../volcengine_maas/llm/llm.py | 231 +-- .../volcengine_maas/llm/models.py | 141 +- .../volcengine_maas/text_embedding/models.py | 10 +- .../text_embedding/text_embedding.py | 59 +- .../model_providers/wenxin/_common.py | 124 +- .../model_providers/wenxin/llm/ernie_bot.py | 188 ++- .../model_providers/wenxin/llm/llm.py | 230 ++- .../wenxin/text_embedding/text_embedding.py | 70 +- .../model_providers/wenxin/wenxin.py | 8 +- .../model_providers/wenxin/wenxin_errors.py | 23 +- .../model_providers/xinference/llm/llm.py | 508 +++--- .../xinference/rerank/rerank.py | 121 +- .../xinference/speech2text/speech2text.py | 69 +- .../text_embedding/text_embedding.py | 82 +- .../model_providers/xinference/tts/tts.py | 180 +- .../xinference/xinference_helper.py | 86 +- .../model_providers/yi/llm/llm.py | 46 +- .../model_runtime/model_providers/yi/yi.py | 8 +- .../model_providers/zhinao/llm/llm.py | 20 +- .../model_providers/zhinao/zhinao.py | 8 +- .../model_providers/zhipuai/_common.py | 5 +- .../model_providers/zhipuai/llm/llm.py | 269 ++- .../zhipuai/text_embedding/text_embedding.py | 33 +- .../model_providers/zhipuai/zhipuai.py | 7 +- .../zhipuai/zhipuai_sdk/__init__.py | 1 - .../zhipuai/zhipuai_sdk/__version__.py | 3 +- .../zhipuai/zhipuai_sdk/_client.py | 21 +- .../api_resource/chat/async_completions.py | 48 +- .../api_resource/chat/completions.py | 36 +- .../zhipuai_sdk/api_resource/embeddings.py | 24 +- .../zhipuai/zhipuai_sdk/api_resource/files.py | 33 +- .../api_resource/fine_tuning/fine_tuning.py | 1 - .../api_resource/fine_tuning/jobs.py | 52 +- .../zhipuai_sdk/api_resource/images.py | 36 +- .../zhipuai/zhipuai_sdk/core/_errors.py | 29 +- .../zhipuai/zhipuai_sdk/core/_http_client.py | 193 +-- .../zhipuai/zhipuai_sdk/core/_request_opt.py | 13 +- .../zhipuai/zhipuai_sdk/core/_response.py | 18 +- .../zhipuai/zhipuai_sdk/core/_sse_client.py | 53 +- .../types/chat/async_chat_completion.py | 2 +- .../zhipuai_sdk/types/chat/chat_completion.py | 2 - .../zhipuai/zhipuai_sdk/types/file_object.py | 2 - .../types/fine_tuning/fine_tuning_job.py | 2 +- .../schema_validators/common_validator.py | 27 +- .../model_credential_schema_validator.py | 1 - .../provider_credential_schema_validator.py | 1 - api/core/model_runtime/utils/encoders.py | 19 +- api/core/model_runtime/utils/helper.py | 2 +- api/core/moderation/api/api.py | 32 +- api/core/moderation/base.py | 5 +- api/core/moderation/input_moderation.py | 12 +- api/core/moderation/keywords/keywords.py | 24 +- .../openai_moderation/openai_moderation.py | 31 +- api/core/moderation/output_moderation.py | 44 +- api/core/ops/base_trace_instance.py | 2 +- api/core/ops/entities/config_entity.py | 23 +- api/core/ops/entities/trace_entity.py | 32 +- .../entities/langsmith_trace_entity.py | 62 +- .../ops/langsmith_trace/langsmith_trace.py | 14 +- api/core/ops/ops_trace_manager.py | 157 +- api/core/ops/utils.py | 8 +- api/core/prompt/advanced_prompt_transform.py | 140 +- .../prompt/agent_history_prompt_transform.py | 22 +- .../entities/advanced_prompt_entities.py | 9 +- .../advanced_prompt_templates.py | 64 +- api/core/prompt/prompt_transform.py | 75 +- api/core/prompt/simple_prompt_transform.py | 206 +-- api/core/prompt/utils/prompt_message_util.py | 67 +- .../prompt/utils/prompt_template_parser.py | 4 +- api/core/provider_manager.py | 348 ++-- api/core/rag/cleaner/clean_processor.py | 30 +- api/core/rag/cleaner/cleaner_base.py | 5 +- .../unstructured_extra_whitespace_cleaner.py | 2 +- ...uctured_group_broken_paragraphs_cleaner.py | 2 +- .../unstructured_non_ascii_chars_cleaner.py | 2 +- ...ructured_replace_unicode_quotes_cleaner.py | 3 +- .../unstructured_translate_text_cleaner.py | 2 +- .../data_post_processor.py | 51 +- api/core/rag/data_post_processor/reorder.py | 1 - .../rag/datasource/keyword/jieba/jieba.py | 136 +- .../jieba/jieba_keyword_table_handler.py | 3 +- .../rag/datasource/keyword/jieba/stopwords.py | 1466 ++++++++++++++++- .../rag/datasource/keyword/keyword_base.py | 10 +- .../rag/datasource/keyword/keyword_factory.py | 9 +- api/core/rag/datasource/retrieval_service.py | 237 +-- .../vdb/analyticdb/analyticdb_vector.py | 55 +- .../datasource/vdb/chroma/chroma_vector.py | 44 +- .../vdb/elasticsearch/elasticsearch_vector.py | 119 +- .../datasource/vdb/milvus/milvus_vector.py | 114 +- .../datasource/vdb/myscale/myscale_vector.py | 22 +- .../vdb/opensearch/opensearch_vector.py | 99 +- .../rag/datasource/vdb/oracle/oraclevector.py | 56 +- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 48 +- .../rag/datasource/vdb/pgvector/pgvector.py | 5 +- .../datasource/vdb/qdrant/qdrant_vector.py | 172 +- .../rag/datasource/vdb/relyt/relyt_vector.py | 60 +- .../datasource/vdb/tencent/tencent_vector.py | 59 +- .../datasource/vdb/tidb_vector/tidb_vector.py | 75 +- api/core/rag/datasource/vdb/vector_base.py | 16 +- api/core/rag/datasource/vdb/vector_factory.py | 54 +- api/core/rag/datasource/vdb/vector_type.py | 28 +- .../vdb/weaviate/weaviate_vector.py | 80 +- api/core/rag/docstore/dataset_docstore.py | 70 +- api/core/rag/extractor/blob/blob.py | 1 + api/core/rag/extractor/csv_extractor.py | 17 +- .../rag/extractor/entity/extract_setting.py | 3 + api/core/rag/extractor/excel_extractor.py | 35 +- api/core/rag/extractor/extract_processor.py | 103 +- api/core/rag/extractor/extractor_base.py | 5 +- .../rag/extractor/firecrawl/firecrawl_app.py | 102 +- .../firecrawl/firecrawl_web_extractor.py | 52 +- api/core/rag/extractor/helpers.py | 4 +- api/core/rag/extractor/html_extractor.py | 13 +- api/core/rag/extractor/markdown_extractor.py | 20 +- api/core/rag/extractor/notion_extractor.py | 143 +- api/core/rag/extractor/pdf_extractor.py | 15 +- api/core/rag/extractor/text_extractor.py | 8 +- .../unstructured_doc_extractor.py | 14 +- .../unstructured_eml_extractor.py | 6 +- .../unstructured_epub_extractor.py | 1 + .../unstructured_markdown_extractor.py | 1 + .../unstructured_msg_extractor.py | 7 +- .../unstructured_ppt_extractor.py | 7 +- .../unstructured_pptx_extractor.py | 6 +- .../unstructured_text_extractor.py | 7 +- .../unstructured_xml_extractor.py | 7 +- api/core/rag/extractor/word_extractor.py | 77 +- .../index_processor/index_processor_base.py | 34 +- .../index_processor_factory.py | 4 +- .../processor/paragraph_index_processor.py | 48 +- .../processor/qa_index_processor.py | 86 +- api/core/rag/models/document.py | 8 +- api/core/rag/rerank/constants/rerank_mode.py | 6 +- api/core/rag/rerank/rerank_model.py | 32 +- api/core/rag/rerank/weight_rerank.py | 39 +- api/core/rag/retrieval/dataset_retrieval.py | 385 +++-- .../output_parser/structured_chat.py | 4 +- api/core/rag/retrieval/retrieval_methods.py | 6 +- .../multi_dataset_function_call_router.py | 24 +- .../router/multi_dataset_react_route.py | 120 +- api/core/rag/splitter/fixed_text_splitter.py | 19 +- api/core/rag/splitter/text_splitter.py | 129 +- api/core/tools/entities/api_entities.py | 49 +- api/core/tools/entities/common_entities.py | 7 +- api/core/tools/entities/tool_bundle.py | 1 + api/core/tools/entities/tool_entities.py | 177 +- api/core/tools/entities/values.py | 132 +- api/core/tools/errors.py | 9 +- api/core/tools/provider/api_tool_provider.py | 172 +- api/core/tools/provider/app_tool_provider.py | 120 +- api/core/tools/provider/builtin/_positions.py | 2 +- .../tools/provider/builtin/aippt/aippt.py | 2 +- .../provider/builtin/aippt/tools/aippt.py | 479 +++--- .../builtin/alphavantage/alphavantage.py | 2 +- .../builtin/alphavantage/tools/query_stock.py | 31 +- .../tools/provider/builtin/arxiv/arxiv.py | 3 +- .../builtin/arxiv/tools/arxiv_search.py | 16 +- api/core/tools/provider/builtin/aws/aws.py | 11 +- .../builtin/aws/tools/apply_guardrail.py | 37 +- .../aws/tools/lambda_translate_utils.py | 79 +- .../builtin/aws/tools/lambda_yaml_to_json.py | 47 +- .../aws/tools/sagemaker_text_rerank.py | 58 +- .../builtin/aws/tools/sagemaker_tts.py | 92 +- .../provider/builtin/azuredalle/azuredalle.py | 8 +- .../builtin/azuredalle/tools/dalle3.py | 67 +- .../builtin/bing/tools/bing_web_search.py | 198 ++- .../tools/provider/builtin/brave/brave.py | 3 +- .../builtin/brave/tools/brave_search.py | 17 +- .../tools/provider/builtin/chart/chart.py | 41 +- .../tools/provider/builtin/chart/tools/bar.py | 27 +- .../provider/builtin/chart/tools/line.py | 31 +- .../tools/provider/builtin/chart/tools/pie.py | 28 +- .../builtin/code/tools/simple_code.py | 14 +- .../tools/provider/builtin/cogview/cogview.py | 11 +- .../builtin/cogview/tools/cogview3.py | 59 +- .../provider/builtin/crossref/crossref.py | 4 +- .../builtin/crossref/tools/query_doi.py | 11 +- .../builtin/crossref/tools/query_title.py | 73 +- .../tools/provider/builtin/dalle/dalle.py | 9 +- .../provider/builtin/dalle/tools/dalle2.py | 55 +- .../provider/builtin/dalle/tools/dalle3.py | 73 +- .../tools/provider/builtin/devdocs/devdocs.py | 3 +- .../builtin/devdocs/tools/searchDevDocs.py | 16 +- api/core/tools/provider/builtin/did/did.py | 9 +- .../tools/provider/builtin/did/did_appx.py | 46 +- .../provider/builtin/did/tools/animations.py | 36 +- .../tools/provider/builtin/did/tools/talks.py | 62 +- .../dingtalk/tools/dingtalk_group_bot.py | 62 +- .../provider/builtin/duckduckgo/duckduckgo.py | 3 +- .../builtin/duckduckgo/tools/ddgo_ai.py | 4 +- .../builtin/duckduckgo/tools/ddgo_img.py | 17 +- .../builtin/duckduckgo/tools/ddgo_search.py | 17 +- .../duckduckgo/tools/ddgo_translate.py | 6 +- .../builtin/feishu/tools/feishu_group_bot.py | 31 +- .../builtin/feishu_base/feishu_base.py | 2 +- .../feishu_base/tools/add_base_record.py | 42 +- .../builtin/feishu_base/tools/create_base.py | 26 +- .../feishu_base/tools/create_base_table.py | 34 +- .../feishu_base/tools/delete_base_records.py | 42 +- .../feishu_base/tools/delete_base_tables.py | 29 +- .../feishu_base/tools/get_base_info.py | 21 +- .../tools/get_tenant_access_token.py | 24 +- .../feishu_base/tools/list_base_records.py | 46 +- .../feishu_base/tools/list_base_tables.py | 25 +- .../feishu_base/tools/read_base_record.py | 34 +- .../feishu_base/tools/update_base_record.py | 46 +- .../feishu_document/feishu_document.py | 6 +- .../feishu_document/tools/create_document.py | 10 +- .../tools/get_document_raw_content.py | 8 +- .../tools/list_document_block.py | 10 +- .../feishu_document/tools/write_document.py | 10 +- .../builtin/feishu_message/feishu_message.py | 6 +- .../feishu_message/tools/send_bot_message.py | 12 +- .../tools/send_webhook_message.py | 12 +- .../provider/builtin/firecrawl/firecrawl.py | 11 +- .../builtin/firecrawl/firecrawl_appx.py | 61 +- .../provider/builtin/firecrawl/tools/crawl.py | 52 +- .../builtin/firecrawl/tools/crawl_job.py | 17 +- .../builtin/firecrawl/tools/scrape.py | 40 +- .../builtin/firecrawl/tools/search.py | 19 +- .../tools/provider/builtin/gaode/gaode.py | 14 +- .../builtin/gaode/tools/gaode_weather.py | 67 +- .../provider/builtin/getimgai/getimgai.py | 9 +- .../builtin/getimgai/getimgai_appx.py | 22 +- .../builtin/getimgai/tools/text2image.py | 34 +- .../tools/provider/builtin/github/github.py | 16 +- .../github/tools/github_repositories.py | 64 +- .../tools/provider/builtin/gitlab/gitlab.py | 18 +- .../builtin/gitlab/tools/gitlab_commits.py | 107 +- .../builtin/gitlab/tools/gitlab_files.py | 61 +- .../tools/provider/builtin/google/google.py | 8 +- .../builtin/google/tools/google_search.py | 23 +- .../google_translate/google_translate.py | 6 +- .../google_translate/tools/translate.py | 34 +- api/core/tools/provider/builtin/hap/hap.py | 2 +- .../builtin/hap/tools/add_worksheet_record.py | 37 +- .../hap/tools/delete_worksheet_record.py | 37 +- .../builtin/hap/tools/get_worksheet_fields.py | 107 +- .../hap/tools/get_worksheet_pivot_data.py | 117 +- .../hap/tools/list_worksheet_records.py | 193 ++- .../builtin/hap/tools/list_worksheets.py | 77 +- .../hap/tools/update_worksheet_record.py | 41 +- api/core/tools/provider/builtin/jina/jina.py | 38 +- .../builtin/jina/tools/jina_reader.py | 67 +- .../builtin/jina/tools/jina_search.py | 41 +- .../builtin/jina/tools/jina_tokenizer.py | 38 +- .../builtin/json_process/json_process.py | 11 +- .../builtin/json_process/tools/delete.py | 23 +- .../builtin/json_process/tools/insert.py | 45 +- .../builtin/json_process/tools/parse.py | 23 +- .../builtin/json_process/tools/replace.py | 55 +- .../provider/builtin/judge0ce/judge0ce.py | 3 +- .../builtin/judge0ce/tools/executeCode.py | 44 +- .../tools/provider/builtin/maths/maths.py | 4 +- .../builtin/maths/tools/eval_expression.py | 21 +- .../provider/builtin/nominatim/nominatim.py | 24 +- .../nominatim/tools/nominatim_lookup.py | 43 +- .../nominatim/tools/nominatim_reverse.py | 46 +- .../nominatim/tools/nominatim_search.py | 46 +- .../builtin/novitaai/_novita_tool_base.py | 24 +- .../provider/builtin/novitaai/novitaai.py | 38 +- .../novitaai/tools/novitaai_createtile.py | 27 +- .../novitaai/tools/novitaai_modelquery.py | 134 +- .../novitaai/tools/novitaai_txt2img.py | 65 +- .../tools/provider/builtin/onebot/onebot.py | 4 +- .../builtin/onebot/tools/send_group_msg.py | 53 +- .../builtin/onebot/tools/send_private_msg.py | 55 +- .../builtin/openweather/openweather.py | 13 +- .../builtin/openweather/tools/weather.py | 14 +- .../provider/builtin/perplexity/perplexity.py | 20 +- .../perplexity/tools/perplexity_search.py | 77 +- .../tools/provider/builtin/pubmed/pubmed.py | 3 +- .../builtin/pubmed/tools/pubmed_search.py | 40 +- .../tools/provider/builtin/qrcode/qrcode.py | 5 +- .../builtin/qrcode/tools/qrcode_generator.py | 41 +- .../tools/provider/builtin/regex/regex.py | 6 +- .../builtin/regex/tools/regex_extract.py | 21 +- .../provider/builtin/searchapi/searchapi.py | 7 +- .../builtin/searchapi/tools/google.py | 23 +- .../builtin/searchapi/tools/google_jobs.py | 34 +- .../builtin/searchapi/tools/google_news.py | 23 +- .../searchapi/tools/youtube_transcripts.py | 17 +- .../tools/provider/builtin/searxng/searxng.py | 8 +- .../builtin/searxng/tools/searxng_search.py | 19 +- .../tools/provider/builtin/serper/serper.py | 7 +- .../builtin/serper/tools/serper_search.py | 30 +- .../builtin/siliconflow/siliconflow.py | 4 +- .../builtin/siliconflow/tools/flux.py | 12 +- .../siliconflow/tools/stable_diffusion.py | 8 +- .../builtin/slack/tools/slack_webhook.py | 27 +- .../tools/provider/builtin/spark/spark.py | 8 +- .../spark/tools/spark_img_generation.py | 39 +- .../tools/provider/builtin/spider/spider.py | 12 +- .../provider/builtin/spider/spiderApp.py | 38 +- .../builtin/spider/tools/scraper_crawler.py | 38 +- .../provider/builtin/stability/stability.py | 1 + .../provider/builtin/stability/tools/base.py | 19 +- .../builtin/stability/tools/text2image.py | 44 +- .../stablediffusion/stablediffusion.py | 1 - .../stablediffusion/tools/stable_diffusion.py | 308 ++-- .../builtin/stackexchange/stackexchange.py | 7 +- .../tools/fetchAnsByStackExQuesID.py | 8 +- .../tools/searchStackExQuestions.py | 14 +- .../tools/provider/builtin/stepfun/stepfun.py | 5 +- .../provider/builtin/stepfun/tools/image.py | 60 +- .../tools/provider/builtin/tavily/tavily.py | 5 +- .../builtin/tavily/tools/tavily_search.py | 28 +- .../provider/builtin/tianditu/tianditu.py | 12 +- .../builtin/tianditu/tools/geocoder.py | 32 +- .../builtin/tianditu/tools/poisearch.py | 67 +- .../builtin/tianditu/tools/staticmap.py | 59 +- api/core/tools/provider/builtin/time/time.py | 3 +- .../builtin/time/tools/current_time.py | 25 +- .../provider/builtin/time/tools/weekday.py | 21 +- .../builtin/trello/tools/create_board.py | 17 +- .../trello/tools/create_list_on_board.py | 19 +- .../trello/tools/create_new_card_on_board.py | 13 +- .../builtin/trello/tools/delete_board.py | 7 +- .../builtin/trello/tools/delete_card.py | 7 +- .../builtin/trello/tools/fetch_all_boards.py | 8 +- .../builtin/trello/tools/get_board_actions.py | 11 +- .../builtin/trello/tools/get_board_by_id.py | 7 +- .../builtin/trello/tools/get_board_cards.py | 7 +- .../trello/tools/get_filterd_board_cards.py | 13 +- .../trello/tools/get_lists_on_board.py | 7 +- .../builtin/trello/tools/update_board.py | 11 +- .../builtin/trello/tools/update_card.py | 10 +- .../tools/provider/builtin/trello/trello.py | 7 +- .../builtin/twilio/tools/send_message.py | 11 +- .../tools/provider/builtin/twilio/twilio.py | 3 +- .../tools/provider/builtin/vanna/vanna.py | 6 +- .../builtin/vectorizer/tools/test_data.py | 2 +- .../builtin/vectorizer/tools/vectorizer.py | 61 +- .../provider/builtin/vectorizer/vectorizer.py | 8 +- .../builtin/webscraper/tools/webscraper.py | 19 +- .../provider/builtin/webscraper/webscraper.py | 7 +- .../builtin/websearch/tools/job_search.py | 18 +- .../builtin/websearch/tools/news_search.py | 16 +- .../builtin/websearch/tools/scholar_search.py | 18 +- .../builtin/websearch/tools/web_search.py | 14 +- .../builtin/wecom/tools/wecom_group_bot.py | 37 +- .../wikipedia/tools/wikipedia_search.py | 1 - .../provider/builtin/wikipedia/wikipedia.py | 3 +- .../wolframalpha/tools/wolframalpha.py | 74 +- .../builtin/wolframalpha/wolframalpha.py | 3 +- .../provider/builtin/yahoo/tools/analytics.py | 60 +- .../provider/builtin/yahoo/tools/news.py | 37 +- .../provider/builtin/yahoo/tools/ticker.py | 19 +- .../tools/provider/builtin/yahoo/yahoo.py | 3 +- .../provider/builtin/youtube/tools/videos.py | 67 +- .../tools/provider/builtin/youtube/youtube.py | 3 +- .../tools/provider/builtin_tool_provider.py | 152 +- api/core/tools/provider/tool_provider.py | 132 +- .../tools/provider/workflow_tool_provider.py | 137 +- api/core/tools/tool/api_tool.py | 189 ++- api/core/tools/tool/builtin_tool.py | 86 +- .../dataset_multi_retriever_tool.py | 157 +- .../dataset_retriever_base_tool.py | 1 + .../dataset_retriever_tool.py | 140 +- api/core/tools/tool/dataset_retriever_tool.py | 51 +- api/core/tools/tool/tool.py | 143 +- api/core/tools/tool/workflow_tool.py | 105 +- api/core/tools/tool_engine.py | 199 ++- api/core/tools/tool_file_manager.py | 29 +- api/core/tools/tool_label_manager.py | 52 +- api/core/tools/tool_manager.py | 435 ++--- api/core/tools/utils/configuration.py | 59 +- api/core/tools/utils/feishu_api_utils.py | 16 +- api/core/tools/utils/message_transformer.py | 109 +- .../tools/utils/model_invocation_utils.py | 66 +- api/core/tools/utils/parser.py | 323 ++-- .../tools/utils/tool_parameter_converter.py | 32 +- api/core/tools/utils/web_reader_tool.py | 73 +- .../utils/workflow_configuration_sync.py | 26 +- api/core/tools/utils/yaml_utils.py | 4 +- .../callbacks/base_workflow_callback.py | 5 +- .../entities/base_node_data_entities.py | 4 +- api/core/workflow/entities/node_entities.py | 61 +- .../workflow/entities/variable_entities.py | 1 + api/core/workflow/entities/variable_pool.py | 12 +- .../workflow/entities/workflow_entities.py | 17 +- .../condition_handlers/base_handler.py | 10 +- .../branch_identify_handler.py | 5 +- .../condition_handlers/condition_handler.py | 8 +- .../condition_handlers/condition_manager.py | 16 +- .../workflow/graph_engine/entities/graph.py | 278 ++-- .../entities/graph_runtime_state.py | 2 +- .../graph_engine/entities/run_condition.py | 2 +- .../entities/runtime_route_state.py | 14 +- .../workflow/graph_engine/graph_engine.py | 308 ++-- api/core/workflow/nodes/answer/answer_node.py | 20 +- .../answer/answer_stream_generate_router.py | 77 +- .../nodes/answer/answer_stream_processor.py | 39 +- .../nodes/answer/base_stream_processor.py | 13 +- api/core/workflow/nodes/answer/entities.py | 10 +- api/core/workflow/nodes/base_node.py | 38 +- api/core/workflow/nodes/code/code_node.py | 168 +- api/core/workflow/nodes/code/entities.py | 7 +- api/core/workflow/nodes/end/end_node.py | 11 +- .../nodes/end/end_stream_generate_router.py | 75 +- .../nodes/end/end_stream_processor.py | 42 +- api/core/workflow/nodes/end/entities.py | 8 +- .../workflow/nodes/http_request/entities.py | 14 +- .../nodes/http_request/http_executor.py | 123 +- .../nodes/http_request/http_request_node.py | 50 +- api/core/workflow/nodes/if_else/entities.py | 1 + .../workflow/nodes/if_else/if_else_node.py | 32 +- api/core/workflow/nodes/iteration/entities.py | 15 +- .../nodes/iteration/iteration_node.py | 138 +- .../nodes/iteration/iteration_start_node.py | 12 +- .../nodes/knowledge_retrieval/entities.py | 17 +- .../knowledge_retrieval_node.py | 224 ++- api/core/workflow/nodes/llm/entities.py | 14 +- api/core/workflow/nodes/llm/llm_node.py | 305 ++-- api/core/workflow/nodes/loop/entities.py | 4 +- api/core/workflow/nodes/loop/loop_node.py | 17 +- .../nodes/parameter_extractor/entities.py | 59 +- .../parameter_extractor_node.py | 460 +++--- .../nodes/parameter_extractor/prompts.py | 154 +- .../nodes/question_classifier/entities.py | 7 +- .../question_classifier_node.py | 182 +- .../question_classifier/template_prompts.py | 2 - api/core/workflow/nodes/start/entities.py | 1 + api/core/workflow/nodes/start/start_node.py | 14 +- .../nodes/template_transform/entities.py | 5 +- .../template_transform_node.py | 40 +- api/core/workflow/nodes/tool/entities.py | 41 +- api/core/workflow/nodes/tool/tool_node.py | 139 +- .../nodes/variable_aggregator/entities.py | 10 +- .../variable_aggregator_node.py | 25 +- .../nodes/variable_assigner/__init__.py | 6 +- .../workflow/nodes/variable_assigner/node.py | 26 +- .../nodes/variable_assigner/node_data.py | 10 +- api/core/workflow/utils/condition/entities.py | 19 +- .../workflow/utils/condition/processor.py | 52 +- .../utils/variable_template_parser.py | 20 +- api/core/workflow/workflow_entry.py | 121 +- api/pyproject.toml | 1 - 724 files changed, 21195 insertions(+), 21138 deletions(-) diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd8afa..6eaea7b1c8419f 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 89c948d2e29f5f..29b428a7c32461 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,17 +25,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True - _ignore_observation_providers = ['wenxin'] + _ignore_observation_providers = ["wenxin"] _historic_prompt_messages: list[PromptMessage] = None _agent_scratchpad: list[AgentScratchpadUnit] = None _instruction: str = None _query: str = None _prompt_messages_tools: list[PromptMessage] = None - def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + def run( + self, + message: Message, + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -46,17 +48,16 @@ def run(self, message: Message, trace_manager = app_generate_entity.trace_manager # check model mode - if 'Observation' not in app_generate_entity.model_conf.stop: + if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append('Observation') + app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools( - instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -65,16 +66,14 @@ def run(self, message: Message, tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -94,17 +93,13 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # recalc llm max tokens prompt_messages = self._organize_prompt_messages() @@ -125,21 +120,20 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output( - chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( - agent_response='', - thought='', - action_str='', - observation='', + agent_response="", + thought="", + action_str="", + observation="", action=None, ) # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): @@ -154,61 +148,51 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=chunk - ), - usage=None - ) + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - scratchpad.thought = scratchpad.thought.strip( - ) or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) # get llm usage - if 'usage' in usage_dict: - increase_usage(llm_usage, usage_dict['usage']) + if "usage" in usage_dict: + increase_usage(llm_usage, usage_dict["usage"]) else: - usage_dict['usage'] = LLMUsage.empty_usage() + usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else {}, + tool_name=scratchpad.action.action_name if scratchpad.action else "", + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, thought=scratchpad.thought, - observation='', + observation="", answer=scratchpad.agent_response, messages_ids=[], - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) if not scratchpad.is_final(): - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = '' + final_answer = "" else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps( - scratchpad.action.action_input) + final_answer = json.dumps(scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" except json.JSONDecodeError: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True # action is tool call, invoke tool @@ -224,21 +208,18 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={ - scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={ - scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # update prompt tool message for prompt_tool in self._prompt_messages_tools: @@ -250,44 +231,45 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): model=model_instance.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=final_answer - ), - usage=llm_usage['usage'] + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] ), - system_fingerprint='' + system_fingerprint="", ) # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name='', + tool_name="", tool_input={}, tool_invoke_meta={}, thought=final_answer, observation={}, answer=final_answer, - messages_ids=[] + messages_ids=[], ) self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) - - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], - message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, ToolInvokeMeta]: + PublishFrom.APPLICATION_MANAGER, + ) + + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action :param action: action @@ -326,13 +308,12 @@ def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file( - tool_name=tool_call_name, value=message_file_id, name=save_as) + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) @@ -342,10 +323,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ convert dict to action """ - return AgentScratchpadUnit.Action( - action_name=action['action'], - action_input=action['action_input'] - ) + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -353,7 +331,7 @@ def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dic """ for key, value in inputs.items(): try: - instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception as e: continue @@ -370,14 +348,14 @@ def _init_react_state(self, query) -> None: @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ - organize prompt messages + organize prompt messages """ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ - format assistant message + format assistant message """ - message = '' + message = "" for scratchpad in agent_scratchpad: if scratchpad.is_final(): message += f"Final Answer: {scratchpad.agent_response}" @@ -390,9 +368,11 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) return message - def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_historic_prompt_messages( + self, current_session_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ - organize historic prompt messages + organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] @@ -403,8 +383,8 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom if not current_scratchpad: current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content or 'I am thinking about how to help you', - action_str='', + thought=message.content or "I am thinking about how to help you", + action_str="", action=None, observation=None, ) @@ -413,12 +393,9 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads( - message.tool_calls[0].function.arguments) - ) - current_scratchpad.action_str = json.dumps( - current_scratchpad.action.to_dict() + action_input=json.loads(message.tool_calls[0].function.arguments), ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) except: pass elif isinstance(message, ToolPromptMessage): @@ -426,23 +403,19 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None result.append(message) if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) historic_prompts = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], history_messages=result, - memory=self.memory + memory=self.memory, ).get_prompt() return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 8debbe5c5ddf0f..bdec6b7ed15d7b 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,14 +19,15 @@ def _organize_system_prompt(self) -> SystemPromptMessage: prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt \ - .replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) return SystemPromptMessage(content=system_prompt) - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -43,7 +44,7 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No def _organize_prompt_messages(self) -> list[PromptMessage]: """ - Organize + Organize """ # organize system prompt system_message = self._organize_system_prompt() @@ -53,7 +54,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: if not agent_scratchpad: assistant_messages = [] else: - assistant_message = AssistantPromptMessage(content='') + assistant_message = AssistantPromptMessage(content="") for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" @@ -71,18 +72,15 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: if assistant_messages: # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([ - system_message, - *query_messages, - *assistant_messages, - UserPromptMessage(content='continue') - ]) + historic_messages = self._organize_historic_prompt_messages( + [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] + ) messages = [ system_message, *historic_messages, *query_messages, *assistant_messages, - UserPromptMessage(content='continue') + UserPromptMessage(content="continue"), ] else: # organize historic prompt messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 9e6eb54f4fe513..9dab956f9a4590 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -13,10 +13,12 @@ def _organize_instruction_prompt(self) -> str: prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) - + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + return system_prompt def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: @@ -46,7 +48,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad - assistant_prompt = '' + assistant_prompt = "" for unit in agent_scratchpad: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" @@ -61,9 +63,10 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: query_prompt = f"Question: {self._query}" # join all messages - prompt = system_prompt \ - .replace("{{historic_messages}}", historic_prompt) \ - .replace("{{agent_scratchpad}}", assistant_prompt) \ + prompt = ( + system_prompt.replace("{{historic_messages}}", historic_prompt) + .replace("{{agent_scratchpad}}", assistant_prompt) .replace("{{query}}", query_prompt) + ) - return [UserPromptMessage(content=prompt)] \ No newline at end of file + return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5274224de5772c..119a88fc7becbf 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ + provider_type: Literal["builtin", "api", "workflow"] provider_id: str tool_name: str @@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. """ + first_prompt: str next_iteration: str @@ -31,6 +33,7 @@ class Action(BaseModel): """ Action Entity. """ + action_name: str action_input: Union[dict, str] @@ -39,8 +42,8 @@ def to_dict(self) -> dict: Convert to dictionary. """ return { - 'action': self.action_name, - 'action_input': self.action_input, + "action": self.action_name, + "action_input": self.action_input, } agent_response: Optional[str] = None @@ -54,10 +57,10 @@ def is_final(self) -> bool: Check if the scratchpad unit is final. """ return self.action is None or ( - 'final' in self.action.action_name.lower() and - 'answer' in self.action.action_name.lower() + "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() ) + class AgentEntity(BaseModel): """ Agent Entity. @@ -67,8 +70,9 @@ class Strategy(Enum): """ Agent Strategy. """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3ee6e47742a18f..27cf561e3d61ac 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -24,11 +24,9 @@ logger = logging.getLogger(__name__) -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, - message: Message, query: str, **kwargs: Any - ) -> Generator[LLMResultChunk, None, None]: +class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ @@ -45,19 +43,17 @@ def run(self, # continue to run until there is not any tool call function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -75,11 +71,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) # recalc llm max tokens @@ -99,11 +91,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response - response = '' + response = "" # save tool call names and inputs - tool_call_names = '' - tool_call_inputs = '' + tool_call_names = "" + tool_call_inputs = "" current_llm_usage = None @@ -111,24 +103,22 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True tool_calls.extend(self.extract_tool_calls(chunk)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: if isinstance(chunk.delta.message.content, list): @@ -148,16 +138,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if self.check_blocking_tool_calls(result): function_call_state = True tool_calls.extend(self.extract_blocking_tool_calls(result)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: increase_usage(llm_usage, result.usage) @@ -171,12 +159,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): response += result.message.content if not result.message.content: - result.message.content = '' + result.message.content = "" + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - yield LLMResultChunk( model=model_instance.model, prompt_messages=result.prompt_messages, @@ -185,32 +173,29 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): index=0, message=result.message, usage=result.usage, - ) + ), ) - assistant_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: - assistant_message.tool_calls=[ + assistant_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=tool_call[0], - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], - arguments=json.dumps(tool_call[2], ensure_ascii=False) - ) - ) for tool_call in tool_calls + name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) + ), + ) + for tool_call in tool_calls ] else: assistant_message.content = response - + self._current_thoughts.append(assistant_message) # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -218,13 +203,13 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): observation=None, answer=response, messages_ids=[], - llm_usage=current_llm_usage + llm_usage=current_llm_usage, ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - - final_answer += response + '\n' + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + final_answer += response + "\n" # call tools tool_responses = [] @@ -235,7 +220,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() + "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), } else: # invoke tool @@ -255,50 +240,49 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) - + tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict() + "meta": tool_invoke_meta.to_dict(), } - + tool_responses.append(tool_response) - if tool_response['tool_response'] is not None: + if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response['tool_response'], + content=tool_response["tool_response"], tool_call_id=tool_call_id, name=tool_call_name, ) - ) + ) if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=None, tool_input=None, - thought=None, + thought=None, tool_invoke_meta={ - tool_response['tool_call_name']: tool_response['meta'] - for tool_response in tool_responses + tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, observation={ - tool_response['tool_call_name']: tool_response['tool_response'] + tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, answer=None, - messages_ids=message_file_ids + messages_ids=message_file_ids, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -308,15 +292,18 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ @@ -325,7 +312,7 @@ def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: if llm_result_chunk.delta.message.tool_calls: return True return False - + def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: """ Check if there is any blocking tool call in llm result @@ -334,7 +321,9 @@ def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_tool_calls( + self, llm_result_chunk: LLMResultChunk + ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -344,17 +333,19 @@ def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, li tool_calls = [] for prompt_message in llm_result_chunk.delta.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result @@ -365,18 +356,22 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list tool_calls = [] for prompt_message in llm_result.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _init_system_message( + self, prompt_template: str, prompt_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ Initialize system message """ @@ -384,13 +379,13 @@ def _init_system_message(self, prompt_template: str, prompt_messages: list[Promp return [ SystemPromptMessage(content=prompt_template), ] - + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) return prompt_messages - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -404,7 +399,7 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages - + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ As for now, gpt supports both fc and vision at the first iteration. @@ -415,17 +410,21 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - content.data if content.type == PromptMessageContentType.TEXT else - '[image]' if content.type == PromptMessageContentType.IMAGE else - '[file]' - for content in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) return prompt_messages def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or '' + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) query_prompt_messages = self._organize_user_query(self.query, []) @@ -433,14 +432,10 @@ def _organize_prompt_messages(self): model_config=self.model_config, prompt_messages=[*query_prompt_messages, *self._current_thoughts], history_messages=self.history_prompt_messages, - memory=self.memory + memory=self.memory, ).get_prompt() - prompt_messages = [ - *self.history_prompt_messages, - *query_prompt_messages, - *self._current_thoughts - ] + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] if len(self._current_thoughts) != 0: # clear messages after the first iteration prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c53fa5000e9958..1a161677dda345 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -9,8 +9,9 @@ class CotAgentOutputParser: @classmethod - def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ - Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def handle_react_stream_output( + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: action = json.loads(json_str) @@ -22,7 +23,7 @@ def parse_action(json_str): action = action[0] for key, value in action.items(): - if 'input' in key.lower(): + if "input" in key.lower(): action_input = value else: action_name = value @@ -33,37 +34,37 @@ def parse_action(json_str): action_input=action_input, ) else: - return json_str or '' + return json_str or "" except: - return json_str or '' - + return json_str or "" + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) yield parse_action(json_text) - - code_block_cache = '' + + code_block_cache = "" code_block_delimiter_count = 0 in_code_block = False - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False got_json = False - action_cache = '' - action_str = 'action:' + action_cache = "" + action_str = "action:" action_idx = 0 - thought_cache = '' - thought_str = 'thought:' + thought_cache = "" + thought_str = "thought:" thought_idx = 0 for response in llm_response: if response.delta.usage: - usage_dict['usage'] = response.delta.usage + usage_dict["usage"] = response.delta.usage response = response.delta.message.content if not isinstance(response, str): continue @@ -72,24 +73,24 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, index = 0 while index < len(response): steps = 1 - delta = response[index:index+steps] - last_character = response[index-1] if index > 0 else '' + delta = response[index : index + steps] + last_character = response[index - 1] if index > 0 else "" - if delta == '`': + if delta == "`": code_block_cache += delta code_block_delimiter_count += 1 else: if not in_code_block: if code_block_delimiter_count > 0: yield code_block_cache - code_block_cache = '' + code_block_cache = "" else: code_block_cache += delta code_block_delimiter_count = 0 if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in ["\n", " ", ""]: index += steps yield delta continue @@ -97,7 +98,7 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue @@ -105,18 +106,18 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue else: if action_cache: yield action_cache - action_cache = '' + action_cache = "" action_idx = 0 - + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in ["\n", " ", ""]: index += steps yield delta continue @@ -124,7 +125,7 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue @@ -132,31 +133,31 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue else: if thought_cache: yield thought_cache - thought_cache = '' + thought_cache = "" thought_idx = 0 if code_block_delimiter_count == 3: if in_code_block: yield from extra_json_from_code_block(code_block_cache) - code_block_cache = '' - + code_block_cache = "" + in_code_block = not in_code_block code_block_delimiter_count = 0 if not in_code_block: # handle single json - if delta == '{': + if delta == "{": json_quote_count += 1 in_json = True json_cache += delta - elif delta == '}': + elif delta == "}": json_cache += delta if json_quote_count > 0: json_quote_count -= 1 @@ -172,12 +173,12 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, if got_json: got_json = False yield parse_action(json_cache) - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False - + if not in_code_block and not in_json: - yield delta.replace('`', '') + yield delta.replace("`", "") index += steps @@ -186,4 +187,3 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, if json_cache: yield parse_action(json_cache) - diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index b0cf1a77fb1772..cb98f5501ddd08 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -91,14 +91,14 @@ ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" REACT_PROMPT_TEMPLATES = { - 'english': { - 'chat': { - 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES + "english": { + "chat": { + "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES, + }, + "completion": { + "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES, }, - 'completion': { - 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES - } } -} \ No newline at end of file +} diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 3dea305e984143..0fd2a779a4800d 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -26,34 +26,24 @@ def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> config_dict = dict(config_dict.items()) additional_features = AppAdditionalFeatures() - additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( - config=config_dict - ) + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, - is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] ) - additional_features.opening_statement, additional_features.suggested_questions = \ - OpeningStatementConfigManager.convert( - config=config_dict - ) + additional_features.opening_statement, additional_features.suggested_questions = ( + OpeningStatementConfigManager.convert(config=config_dict) + ) additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( config=config_dict ) - additional_features.more_like_this = MoreLikeThisConfigManager.convert( - config=config_dict - ) + additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict) - additional_features.speech_to_text = SpeechToTextConfigManager.convert( - config=config_dict - ) + additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict) - additional_features.text_to_speech = TextToSpeechConfigManager.convert( - config=config_dict - ) + additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict) return additional_features diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 1ca8b1e3b8fed2..037037e6ca1cf0 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -7,25 +7,24 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: - sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None - if sensitive_word_avoidance_dict.get('enabled'): + if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config"), ) else: return None @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ - -> tuple[dict, list[str]]: + def validate_and_set_defaults( + cls, tenant_id, config: dict, only_structure_validate: bool = False + ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): - config["sensitive_word_avoidance"] = { - "enabled": False - } + config["sensitive_word_avoidance"] = {"enabled": False} if not isinstance(config["sensitive_word_avoidance"], dict): raise ValueError("sensitive_word_avoidance must be of dict type") @@ -41,10 +40,6 @@ def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_valid typ = config["sensitive_word_avoidance"]["type"] sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] - ModerationFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config - ) + ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index dc65d4439b6e03..6e89f19508c01f 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -12,67 +12,70 @@ def convert(cls, config: dict) -> Optional[AgentEntity]: :param config: model config args """ - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode']: + if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]: + agent_dict = config.get("agent_mode", {}) + agent_strategy = agent_dict.get("strategy", "cot") - agent_dict = config.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': + if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': + elif agent_strategy == "cot" or agent_strategy == "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy - if config['model']['provider'] == 'openai': + if config["model"]["provider"] == "openai": strategy = AgentEntity.Strategy.FUNCTION_CALLING else: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) >= 4: if "enabled" not in tool or not tool["enabled"]: continue agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool.get('tool_parameters', {}) + "provider_type": tool["provider_type"], + "provider_id": tool["provider_id"], + "tool_name": tool["tool_name"], + "tool_parameters": tool.get("tool_parameters", {}), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if 'strategy' in config['agent_mode'] and \ - config['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ + "react_router", + "router", + ]: + agent_prompt = agent_dict.get("prompt", None) or {} # check model mode - model_mode = config.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': + model_mode = config.get("model", {}).get("mode", "completion") + if model_mode == "completion": agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"] + ), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"] + ), ) return AgentEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) + max_iteration=agent_dict.get("max_iteration", 5), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 1a621d2090f759..ff131b62e27543 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -15,39 +15,38 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: :param config: model config args """ dataset_ids = [] - if 'datasets' in config.get('dataset_configs', {}): - datasets = config.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) + if "datasets" in config.get("dataset_configs", {}): + datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) - for dataset in datasets.get('datasets', []): + for dataset in datasets.get("datasets", []): keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': + if len(keys) == 0 or keys[0] != "dataset": continue - dataset = dataset['dataset'] + dataset = dataset["dataset"] - if 'enabled' not in dataset or not dataset['enabled']: + if "enabled" not in dataset or not dataset["enabled"]: continue - dataset_id = dataset.get('id', None) + dataset_id = dataset.get("id", None) if dataset_id: dataset_ids.append(dataset_id) - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode'] \ - and config['agent_mode']['enabled']: + if ( + "agent_mode" in config + and config["agent_mode"] + and "enabled" in config["agent_mode"] + and config["agent_mode"]["enabled"] + ): + agent_dict = config.get("agent_mode", {}) - agent_dict = config.get('agent_mode', {}) - - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) == 1: # old standard key = list(tool.keys())[0] - if key != 'dataset': + if key != "dataset": continue tool_item = tool[key] @@ -55,30 +54,28 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: if "enabled" not in tool_item or not tool_item["enabled"]: continue - dataset_id = tool_item['id'] + dataset_id = tool_item["id"] dataset_ids.append(dataset_id) if len(dataset_ids) == 0: return None # dataset configs - if 'dataset_configs' in config and config.get('dataset_configs'): - dataset_configs = config.get('dataset_configs') + if "dataset_configs" in config and config.get("dataset_configs"): + dataset_configs = config.get("dataset_configs") else: - dataset_configs = { - 'retrieval_model': 'multiple' - } - query_variable = config.get('dataset_query_variable') + dataset_configs = {"retrieval_model": "multiple"} + query_variable = config.get("dataset_query_variable") - if dataset_configs['retrieval_model'] == 'single': + if dataset_configs["retrieval_model"] == "single": return DatasetEntity( dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) + dataset_configs["retrieval_model"] + ), + ), ) else: return DatasetEntity( @@ -86,15 +83,15 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] + dataset_configs["retrieval_model"] ), - top_k=dataset_configs.get('top_k', 4), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights'), - reranking_enabled=dataset_configs.get('reranking_enabled', True), - rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), - ) + top_k=dataset_configs.get("top_k", 4), + score_threshold=dataset_configs.get("score_threshold"), + reranking_model=dataset_configs.get("reranking_model"), + weights=dataset_configs.get("weights"), + reranking_enabled=dataset_configs.get("reranking_enabled", True), + rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), + ), ) @classmethod @@ -111,13 +108,10 @@ def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: di # dataset_configs if not config.get("dataset_configs"): - config["dataset_configs"] = {'retrieval_model': 'single'} + config["dataset_configs"] = {"retrieval_model": "single"} if not config["dataset_configs"].get("datasets"): - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } + config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") @@ -125,8 +119,9 @@ def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: di if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = (config.get("dataset_configs") - and config["dataset_configs"].get("datasets", {}).get("datasets")) + need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( + "datasets", {} + ).get("datasets") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion @@ -148,10 +143,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod """ # Extract dataset config for legacy compatibility if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -188,7 +180,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod if not isinstance(tool_item["enabled"], bool): raise ValueError("enabled in agent_mode.tools must be of boolean type") - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5c9b2cfec7babf..a91b9f0f020073 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,9 +11,7 @@ class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, - skip_check: bool = False) \ - -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. :param app_config: app config @@ -25,9 +23,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=app_config.tenant_id, - provider=model_config.provider, - model_type=ModelType.LLM + tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) provider_name = provider_model_bundle.configuration.provider.provider @@ -38,8 +34,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, # check model credentials model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model_config.model + model_type=ModelType.LLM, model=model_config.model ) if model_credentials is None: @@ -51,8 +46,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, if not skip_check: # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, - model_type=ModelType.LLM + model=model_config.model, model_type=ModelType.LLM ) if provider_model is None: @@ -69,24 +63,18 @@ def convert(cls, app_config: EasyUIBasedAppConfig, # model config completion_params = model_config.parameters stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = model_config.mode if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=model_config.model, - credentials=model_credentials - ) + mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) model_mode = mode_enum.value - model_schema = model_type_instance.get_model_schema( - model_config.model, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 730a9527cf7315..b5e4554181c06e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -13,23 +13,23 @@ def convert(cls, config: dict) -> ModelConfigEntity: :param config: model config args """ # model config - model_config = config.get('model') + model_config = config.get("model") if not model_config: raise ValueError("model is required") - completion_params = model_config.get('completion_params') + completion_params = model_config.get("completion_params") stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode - model_mode = model_config.get('mode') + model_mode = model_config.get("mode") return ModelConfigEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], mode=model_mode, parameters=completion_params, stop=stop, @@ -43,7 +43,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, :param tenant_id: tenant id :param config: app model config args """ - if 'model' not in config: + if "model" not in config: raise ValueError("model is required") if not isinstance(config["model"], dict): @@ -52,17 +52,16 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, # model.provider provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name - if 'name' not in config["model"]: + if "name" not in config["model"]: raise ValueError("model.name is required") provider_manager = ProviderManager() models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM + provider=config["model"]["provider"], model_type=ModelType.LLM ) if not models: @@ -80,12 +79,12 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, # model.mode if model_mode: - config['model']["mode"] = model_mode + config["model"]["mode"] = model_mode else: - config['model']["mode"] = "completion" + config["model"]["mode"] = "completion" # model.completion_params - if 'completion_params' not in config["model"]: + if "completion_params" not in config["model"]: raise ValueError("model.completion_params is required") config["model"]["completion_params"] = cls.validate_model_completion_params( @@ -101,7 +100,7 @@ def validate_model_completion_params(cls, cp: dict) -> dict: raise ValueError("model.completion_params must be of object type") # stop - if 'stop' not in cp: + if "stop" not in cp: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 1f410758aa41da..de91c9a0657bdf 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -14,39 +14,33 @@ def convert(cls, config: dict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") - prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"]) if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: simple_prompt_template = config.get("pre_prompt", "") - return PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) + return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template) else: advanced_chat_prompt_template = None chat_prompt_config = config.get("chat_prompt_config", {}) if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) + chat_prompt_messages.append( + {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], + "prompt": completion_prompt_config["prompt"]["text"], } - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + if "conversation_histories_role" in completion_prompt_config: + completion_prompt_template_params["role_prefix"] = { + "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], + "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( @@ -56,7 +50,7 @@ def convert(cls, config: dict) -> PromptTemplateEntity: return PromptTemplateEntity( prompt_type=prompt_type, advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template + advanced_completion_prompt_template=advanced_completion_prompt_template, ) @classmethod @@ -72,7 +66,7 @@ def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dic config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] - if config['prompt_type'] not in prompt_type_vals: + if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") # chat_prompt_config @@ -89,27 +83,28 @@ def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dic if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required " - "when prompt_type is advanced") + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config["chat_prompt_config"] and not config["completion_prompt_config"]: + raise ValueError( + "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced" + ) model_mode_vals = [mode.value for mode in ModelMode] - if config['model']["mode"] not in model_mode_vals: + if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] + assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human" if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config['model']["mode"] == ModelMode.CHAT.value: - prompt_list = config['chat_prompt_config']['prompt'] + if config["model"]["mode"] == ModelMode.CHAT.value: + prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: raise ValueError("prompt messages must be less than 10") diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 15fa4d99fd7bfc..2c0232c7430960 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -16,32 +16,30 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV variable_entities = [] # old external_data_tools - external_data_tools = config.get('external_data_tools', []) + external_data_tools = config.get("external_data_tools", []) for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: + if "enabled" not in external_data_tool or not external_data_tool["enabled"]: continue external_data_variables.append( ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] + variable=external_data_tool["variable"], + type=external_data_tool["type"], + config=external_data_tool["config"], ) ) # variables and external_data_tools - for variables in config.get('user_input_form', []): + for variables in config.get("user_input_form", []): variable_type = list(variables.keys())[0] if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: variable = variables[variable_type] - if 'config' not in variable: + if "config" not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=variable['variable'], - type=variable['type'], - config=variable['config'] + variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) elif variable_type in [ @@ -54,13 +52,13 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV variable_entities.append( VariableEntity( type=variable_type, - variable=variable.get('variable'), - description=variable.get('description'), - label=variable.get('label'), - required=variable.get('required', False), - max_length=variable.get('max_length'), - options=variable.get('options'), - default=variable.get('default'), + variable=variable.get("variable"), + description=variable.get("description"), + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options"), + default=variable.get("default"), ) ) @@ -103,13 +101,13 @@ def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[s raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] - if 'label' not in form_item: + if "label" not in form_item: raise ValueError("label is required in user_input_form") if not isinstance(form_item["label"], str): raise ValueError("label in user_input_form must be of string type") - if 'variable' not in form_item: + if "variable" not in form_item: raise ValueError("variable is required in user_input_form") if not isinstance(form_item["variable"], str): @@ -117,26 +115,24 @@ def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[s pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") + raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number") variables.append(form_item["variable"]) - if 'required' not in form_item or not form_item["required"]: + if "required" not in form_item or not form_item["required"]: form_item["required"] = False if not isinstance(form_item["required"], bool): raise ValueError("required in user_input_form must be of boolean type") if key == "select": - if 'options' not in form_item or not form_item["options"]: + if "options" not in form_item or not form_item["options"]: form_item["options"] = [] if not isinstance(form_item["options"], list): raise ValueError("options in user_input_form must be a list of strings") - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: + if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]: raise ValueError("default value in user_input_form must be in the options list") return config, ["user_input_form"] @@ -168,10 +164,6 @@ def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: d typ = tool["type"] config = tool["config"] - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config) return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index bbb10d3d76429e..d208db2b01dc29 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str mode: Optional[str] = None @@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -43,6 +46,7 @@ class RolePrefixEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -60,11 +64,12 @@ class PromptType(Enum): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str) -> "PromptType": """ Get value of given mode. @@ -74,7 +79,7 @@ def value_of(cls, value: str) -> 'PromptType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None @@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. """ + variable: str type: str config: dict[str, Any] = {} @@ -125,11 +131,12 @@ class RetrieveStrategy(Enum): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str) -> "RetrieveStrategy": """ Get value of given mode. @@ -139,25 +146,24 @@ def value_of(cls, value: str) -> 'RetrieveStrategy': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = .0 - rerank_mode: Optional[str] = 'reranking_model' + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" reranking_model: Optional[dict] = None weights: Optional[dict] = None reranking_enabled: Optional[bool] = True - - class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel): """ Tracing Config Entity. """ + enabled: bool tracing_provider: str - - class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileExtraConfig] = None opening_statement: Optional[str] = None @@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel): text_to_speech: Optional[TextToSpeechEntity] = None trace_config: Optional[TracingConfigEntity] = None + class AppConfig(BaseModel): """ Application Config Entity. """ + tenant_id: str app_id: str app_mode: AppMode @@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum): """ App Model Config From. """ - ARGS = 'args' - APP_LATEST_CONFIG = 'app-latest-config' - CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" class EasyUIBasedAppConfig(AppConfig): """ Easy UI Based App Config Entity. """ + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str app_model_config_dict: dict @@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 3da3c2eddb83f3..5f7fc99151d2e8 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -13,21 +13,19 @@ def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[ :param config: model config args :param is_vision: if True, the feature is vision feature """ - file_upload_dict = config.get('file_upload') + file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get('image'): - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: + if file_upload_dict.get("image"): + if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: image_config = { - 'number_limits': file_upload_dict['image']['number_limits'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] + "number_limits": file_upload_dict["image"]["number_limits"], + "transfer_methods": file_upload_dict["image"]["transfer_methods"], } if is_vision: - image_config['detail'] = file_upload_dict['image']['detail'] + image_config["detail"] = file_upload_dict["image"]["detail"] - return FileExtraConfig( - image_config=image_config - ) + return FileExtraConfig(image_config=image_config) return None @@ -49,21 +47,21 @@ def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tupl if not config["file_upload"].get("image"): config["file_upload"]["image"] = {"enabled": False} - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] + if config["file_upload"]["image"]["enabled"]: + number_limits = config["file_upload"]["image"]["number_limits"] if number_limits < 1 or number_limits > 6: raise ValueError("number_limits must be in [1, 6]") if is_vision: - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: + detail = config["file_upload"]["image"]["detail"] + if detail not in ["high", "low"]: raise ValueError("detail must be in ['high', 'low']") - transfer_methods = config['file_upload']['image']['transfer_methods'] + transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ['remote_url', 'local_file']: + if method not in ["remote_url", "local_file"]: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index 2ba99a5c40d5fc..496e1beeecfa0f 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ more_like_this = False - more_like_this_dict = config.get('more_like_this') + more_like_this_dict = config.get("more_like_this") if more_like_this_dict: - if more_like_this_dict.get('enabled'): + if more_like_this_dict.get("enabled"): more_like_this = True return more_like_this @@ -22,9 +22,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("more_like_this"): - config["more_like_this"] = { - "enabled": False - } + config["more_like_this"] = {"enabled": False} if not isinstance(config["more_like_this"], dict): raise ValueError("more_like_this must be of dict type") diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 0d8a71bfcf4d8b..b4dacbc409044a 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,5 +1,3 @@ - - class OpeningStatementConfigManager: @classmethod def convert(cls, config: dict) -> tuple[str, list]: @@ -9,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get('opening_statement') + opening_statement = config.get("opening_statement") # suggested questions - suggested_questions_list = config.get('suggested_questions') + suggested_questions_list = config.get("suggested_questions") return opening_statement, suggested_questions_list diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index fca58e12e883da..d098abac2fa2e7 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -2,9 +2,9 @@ class RetrievalResourceConfigManager: @classmethod def convert(cls, config: dict) -> bool: show_retrieve_source = False - retriever_resource_dict = config.get('retriever_resource') + retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: - if retriever_resource_dict.get('enabled'): + if retriever_resource_dict.get("enabled"): show_retrieve_source = True return show_retrieve_source @@ -17,9 +17,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("retriever_resource"): - config["retriever_resource"] = { - "enabled": False - } + config["retriever_resource"] = {"enabled": False} if not isinstance(config["retriever_resource"], dict): raise ValueError("retriever_resource must be of dict type") diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index 88b4be25d3e216..e10ae03e043b78 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ speech_to_text = False - speech_to_text_dict = config.get('speech_to_text') + speech_to_text_dict = config.get("speech_to_text") if speech_to_text_dict: - if speech_to_text_dict.get('enabled'): + if speech_to_text_dict.get("enabled"): speech_to_text = True return speech_to_text @@ -22,9 +22,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("speech_to_text"): - config["speech_to_text"] = { - "enabled": False - } + config["speech_to_text"] = {"enabled": False} if not isinstance(config["speech_to_text"], dict): raise ValueError("speech_to_text must be of dict type") diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index c6cab012207b9c..9ac5114d12dd44 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ suggested_questions_after_answer = False - suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") if suggested_questions_after_answer_dict: - if suggested_questions_after_answer_dict.get('enabled'): + if suggested_questions_after_answer_dict.get("enabled"): suggested_questions_after_answer = True return suggested_questions_after_answer @@ -22,15 +22,15 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("suggested_questions_after_answer"): - config["suggested_questions_after_answer"] = { - "enabled": False - } + config["suggested_questions_after_answer"] = {"enabled": False} if not isinstance(config["suggested_questions_after_answer"], dict): raise ValueError("suggested_questions_after_answer must be of dict type") - if "enabled" not in config["suggested_questions_after_answer"] or not \ - config["suggested_questions_after_answer"]["enabled"]: + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): config["suggested_questions_after_answer"]["enabled"] = False if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index f11e268e7380db..1c7598178527b4 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -10,13 +10,13 @@ def convert(cls, config: dict): :param config: model config args """ text_to_speech = None - text_to_speech_dict = config.get('text_to_speech') + text_to_speech_dict = config.get("text_to_speech") if text_to_speech_dict: - if text_to_speech_dict.get('enabled'): + if text_to_speech_dict.get("enabled"): text_to_speech = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), ) return text_to_speech @@ -29,11 +29,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("text_to_speech"): - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} if not isinstance(config["text_to_speech"], dict): raise ValueError("text_to_speech must be of dict type") diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index c3d0e8ba037cae..b52f235849f665 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): """ Advanced Chatbot App Config Entity. """ + pass class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_mode = AppMode.value_of(app_model.mode) @@ -34,13 +33,9 @@ def get_app_config(cls, app_model: App, app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -58,8 +53,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -69,7 +63,8 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -86,9 +81,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) @@ -98,4 +91,3 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: filtered_config = {key: config.get(key) for key in related_config_keys} return filtered_config - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 638cc07461d714..1277dcebc55ce0 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -34,7 +34,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -44,7 +45,8 @@ def generate( @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -53,14 +55,14 @@ def generate( ) -> dict: ... def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ) -> dict[str, Any] | Generator[str, Any, None]: + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -71,44 +73,37 @@ def generate( :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} # get conversation conversation = None - conversation_id = args.get('conversation_id') + conversation_id = args.get("conversation_id") if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -130,7 +125,7 @@ def generate( stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -140,16 +135,12 @@ def generate( invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream + stream=stream, ) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -161,16 +152,13 @@ def single_iteration_generate(self, app_model: App, :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( @@ -178,18 +166,15 @@ def single_iteration_generate(self, app_model: App, app_config=app_config, conversation_id=None, inputs={}, - query='', + query="", files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -199,17 +184,19 @@ def single_iteration_generate(self, app_model: App, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, conversation=None, - stream=stream + stream=stream, ) - def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Optional[Conversation] = None, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -225,10 +212,7 @@ def _generate(self, *, is_first_conversation = True # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -243,18 +227,21 @@ def _generate(self, *, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - 'context': contextvars.copy_context(), - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) worker_thread.start() @@ -269,17 +256,17 @@ def _generate(self, *, stream=stream, ) - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -302,7 +289,7 @@ def _generate_worker(self, flask_app: Flask, application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) runner.run() @@ -310,14 +297,13 @@ def _generate_worker(self, flask_app: Flask, pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 0caff4a2e395d7..d9fc5995423d82 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( - content_text=text_content.strip(), - user="responding_tts", - tenant_id=tenant_id, - voice=voice + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice ) @@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue): except Exception as e: logging.getLogger(__name__).warning(e) break - audio_queue.put(AudioTrunk("finish", b'')) + audio_queue.put(AudioTrunk("finish", b"")) class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.msg_text = '' + self.msg_text = "" self._audio_queue = queue.Queue() self._msg_queue = queue.Queue() - self.match = re.compile(r'[。.!?]') + self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.TTS + tenant_id=self.tenant_id, model_type=ModelType.TTS ) self.voices = self.model_instance.get_tts_voices() - values = [voice.get('value') for voice in self.voices] + values = [voice.get("value") for voice in self.voices] self.voice = voice if not voice or voice not in values: - self.voice = self.voices[0].get('value') + self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 self._last_audio_event = None self._runtime_thread = threading.Thread(target=self._runtime).start() @@ -85,8 +80,9 @@ def _runtime(self): message = self._msg_queue.get() if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: - futures_result = self.executor.submit(_invoiceTTS, self.msg_text, - self.model_instance, self.tenant_id, self.voice) + futures_result = self.executor.submit( + _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): @@ -94,21 +90,20 @@ def _runtime(self): elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): - self.msg_text += message.event.outputs.get('output', '') + self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): self.MAX_SENTENCE += 1 - text_content = ''.join(sentence_arr) - futures_result = self.executor.submit(_invoiceTTS, text_content, - self.model_instance, - self.tenant_id, - self.voice) + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) if text_tmp: self.msg_text = text_tmp else: - self.msg_text = '' + self.msg_text = "" except Exception as e: self.logger.warning(e) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 4da3d093d276c5..90f547b0f2a64e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, ) -> None: """ :param application_generate_entity: application generate entity @@ -66,11 +66,11 @@ def run(self) -> None: app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") user_id = None if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: @@ -81,7 +81,7 @@ def run(self) -> None: user_id = self.application_generate_entity.user_id workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) if self.application_generate_entity.single_iteration_run: @@ -89,7 +89,7 @@ def run(self) -> None: graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: inputs = self.application_generate_entity.inputs @@ -98,26 +98,27 @@ def run(self) -> None: # moderation if self.handle_input_moderation( - app_record=app_record, - app_generate_entity=self.application_generate_entity, - inputs=inputs, - query=query, - message_id=self.message.id + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=self.message, - query=query, - app_generate_entity=self.application_generate_entity + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, ): return # Init conversation variables stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: conversation_variables = session.scalars(stmt).all() @@ -190,12 +191,12 @@ def run(self) -> None: self._handle_event(workflow_entry, event) def handle_input_moderation( - self, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -217,18 +218,14 @@ def handle_input_moderation( message_id=message_id, ) except ModerationException as e: - self._complete_with_stream_output( - text=str(e), - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION - ) + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -246,32 +243,21 @@ def handle_annotation_reply(self, app_record: App, ) if annotation_reply: - self._publish_event( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) - ) + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) self._complete_with_stream_output( - text=annotation_reply.content, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _complete_with_stream_output(self, - text: str, - stopped_by: QueueStopEvent.StopBy) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output :param text: text :return: """ - self._publish_event( - QueueTextChunkEvent( - text=text - ) - ) + self._publish_event(QueueTextChunkEvent(text=text)) - self._publish_event( - QueueStopEvent(stopped_by=stopped_by) - ) + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index ef579827b47c7e..5fbd3e9a94906f 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -28,15 +28,15 @@ def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) """ blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -50,13 +50,15 @@ def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -67,14 +69,14 @@ def convert_stream_full_response(cls, stream_response: Generator[AppStreamRespon sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -85,7 +87,9 @@ def convert_stream_full_response(cls, stream_response: Generator[AppStreamRespon yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream simple response. :param stream_response: stream response @@ -96,20 +100,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResp sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fb013cd1b1f7ed..f3e1a49cc2243a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow @@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow_system_variables: dict[SystemVariableKey, Any] def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -123,13 +124,10 @@ def process(self): # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -147,7 +145,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: - extras['metadata'] = stream_response.metadata + extras["metadata"] = stream_response.metadata return ChatbotAppBlockingResponse( task_id=stream_response.task_id, @@ -158,15 +156,17 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.answer, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. :return: @@ -176,7 +176,7 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) def _listenAudioMsg(self, publisher, task_id: str): @@ -187,17 +187,20 @@ def _listenAudioMsg(self, publisher, task_id: str): return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: @@ -228,12 +231,12 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -267,22 +270,18 @@ def _process_stream_response( db.session.close() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -293,7 +292,7 @@ def _process_stream_response( response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -304,62 +303,52 @@ def _process_stream_response( response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -372,20 +361,16 @@ def _process_stream_response( ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -399,11 +384,10 @@ def _process_stream_response( ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) yield self._error_to_stream_response(self._handle_error(err_event, self._message)) break elif isinstance(event, QueueStopEvent): @@ -420,8 +404,7 @@ def _process_stream_response( ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) # Save message @@ -434,8 +417,9 @@ def _process_stream_response( self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -445,8 +429,9 @@ def _process_stream_response( self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -472,7 +457,7 @@ def _process_stream_response( yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -502,8 +487,9 @@ def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage @@ -523,7 +509,7 @@ def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -533,15 +519,13 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata.copy() + extras["metadata"] = self._task_state.metadata.copy() - if 'annotation_reply' in extras['metadata']: - del extras['metadata']['annotation_reply'] + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _handle_output_moderation_chunk(self, text: str) -> bool: @@ -555,14 +539,11 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index f495ebbf35fe40..9040f18bfd71d3 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): """ Agent Chatbot App Config Entity. """ + agent: Optional[AgentEntity] = None class AgentChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config :param app_model: app model @@ -66,22 +70,12 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - agent=AgentConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -128,7 +122,8 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -145,13 +140,15 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # dataset configs # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -170,10 +167,7 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t :param config: app model config args """ if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -187,8 +181,9 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t if not config["agent_mode"].get("strategy"): config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [member.value for member in - list(PlanningStrategy.__members__.values())]: + if config["agent_mode"]["strategy"] not in [ + member.value for member in list(PlanningStrategy.__members__.values()) + ]: raise ValueError("strategy in agent_mode must be in the specified strategy list") if not config["agent_mode"].get("tools"): @@ -210,7 +205,7 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t raise ValueError("enabled in agent_mode.tools must be of boolean type") if key == "dataset": - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index daf37f4a50dd31..7ba6bbab9409da 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -30,7 +30,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -39,19 +40,17 @@ def generate( @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, ) -> dict: ... - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -62,60 +61,48 @@ def generate(self, app_model: App, :param stream: is stream """ if not stream: - raise ValueError('Agent Chat App does not support blocking mode') + raise ValueError("Agent Chat App does not support blocking mode") - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -124,7 +111,7 @@ def generate(self, app_model: App, app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -145,14 +132,11 @@ def generate(self, app_model: App, invoke_from=invoke_from, extras=extras, call_depth=0, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -161,17 +145,20 @@ def generate(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -185,13 +172,11 @@ def generate(self, app_model: App, stream=stream, ) - return AgentChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( - self, flask_app: Flask, + self, + flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -224,14 +209,13 @@ def _generate_worker( pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d1bbf679c567fd..6b676b0353a525 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner): """ def run( - self, application_generate_entity: AgentChatAppGenerateEntity, + self, + application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -65,7 +66,7 @@ def run( prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -73,13 +74,10 @@ def run( # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -91,7 +89,7 @@ def run( inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -103,7 +101,7 @@ def run( app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -111,7 +109,7 @@ def run( app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -122,13 +120,13 @@ def run( message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -136,7 +134,7 @@ def run( app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -148,7 +146,7 @@ def run( app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # reorganize all inputs and template to prompt messages @@ -161,14 +159,14 @@ def run( inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: @@ -177,9 +175,9 @@ def run( agent_entity = app_config.agent # load tool variables - tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, - user_id=application_generate_entity.user_id, - tenant_id=app_config.tenant_id) + tool_conversation_variables = self._load_tool_variables( + conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id + ) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) @@ -187,7 +185,7 @@ def run( # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, @@ -238,7 +236,7 @@ def run( prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, - model_instance=model_instance + model_instance=model_instance, ) invoke_result = runner.run( @@ -252,17 +250,21 @@ def run( invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream, - agent=True + agent=True, ) def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: """ load tool variables from database """ - tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id - ).first() + tool_variables: ToolConversationVariables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tenant_id, + ) + .first() + ) if tool_variables: # save tool variables to session, so that we can update it later @@ -273,34 +275,40 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st conversation_id=conversation_id, user_id=user_id, tenant_id=tenant_id, - variables_str='[]', + variables_str="[]", ) db.session.add(tool_variables) db.session.commit() return tool_variables - - def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + + def _convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: """ convert db variables to tool variables """ - return ToolRuntimeVariablePool(**{ - 'conversation_id': db_variables.conversation_id, - 'user_id': db_variables.user_id, - 'tenant_id': db_variables.tenant_id, - 'pool': db_variables.variables - }) - - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, - message: Message) -> LLMUsage: + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) + + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: """ Get usage of all agent thoughts :param model_config: model config :param message: message :return: """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) all_message_tokens = 0 all_answer_tokens = 0 @@ -312,8 +320,5 @@ def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredenti model_type_instance = cast(LargeLanguageModel, model_type_instance) return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens ) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 118d82c495f1fe..629c309c065458 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -23,15 +23,15 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -93,20 +95,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStr sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a196d36be5f1f7..73025d99d056ab 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC): _blocking_response_type: type[AppBlockingResponse] @classmethod - def convert(cls, response: Union[ - AppBlockingResponse, - Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: + def convert( + cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom + ) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: + def _generate_full_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_full_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_full_response() else: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_simple_response(response) else: + def _generate_simple_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_simple_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_simple_response() @@ -54,14 +55,16 @@ def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse @classmethod @abstractmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @abstractmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @@ -72,24 +75,26 @@ def _get_simple_metadata(cls, metadata: dict[str, Any]): :return: """ # show_retrieve_source - if 'retriever_resources' in metadata: - metadata['retriever_resources'] = [] - for resource in metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) + if "retriever_resources" in metadata: + metadata["retriever_resources"] = [] + for resource in metadata["retriever_resources"]: + metadata["retriever_resources"].append( + { + "segment_id": resource["segment_id"], + "position": resource["position"], + "document_name": resource["document_name"], + "score": resource["score"], + "content": resource["content"], + } + ) # show annotation reply - if 'annotation_reply' in metadata: - del metadata['annotation_reply'] + if "annotation_reply" in metadata: + del metadata["annotation_reply"] # show usage - if 'usage' in metadata: - del metadata['usage'] + if "usage" in metadata: + del metadata["usage"] return metadata @@ -101,16 +106,16 @@ def _error_to_stream_response(cls, e: Exception) -> dict: :return: """ error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + ValueError: {"code": "invalid_param", "status": 400}, + ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 + "code": "provider_quota_exceeded", + "message": "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + "status": 400, }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} + ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400}, + InvokeError: {"code": "completion_request_error", "status": 400}, } # Determine the response based on the type of exception @@ -120,13 +125,13 @@ def _error_to_stream_response(cls, e: Exception) -> dict: data = v if data: - data.setdefault('message', getattr(e, 'description', str(e))) + data.setdefault("message", getattr(e, "description", str(e))) else: logging.error(e) data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, } return data diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 9e331dff4d64e2..ce6f7d43387599 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -16,10 +16,10 @@ def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_conf def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.variable} is required in input form') + raise ValueError(f"{var.variable} is required in input form") if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? - return var.default or '' + return var.default or "" if ( var.type in ( @@ -34,7 +34,7 @@ def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: - if '.' in user_input_value: + if "." in user_input_value: return float(user_input_value) else: return int(user_input_value) @@ -43,14 +43,14 @@ def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + raise ValueError(f"{var.variable} in input form must be one of the following: {options}") elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') + raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") return user_input_value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): - return value.replace('\x00', '') + return value.replace("\x00", "") return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f929a979f129de..df972756d55243 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -24,9 +24,7 @@ class PublishFrom(Enum): class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") @@ -34,9 +32,10 @@ def __init__(self, task_id: str, self._user_id = user_id self._invoke_from = invoke_from - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, - f"{user_prefix}-{self._user_id}") + user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + redis_client.setex( + AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" + ) q = queue.Queue() @@ -66,8 +65,7 @@ def listen(self) -> Generator: # publish two messages to make sure the client can receive the stop signal # and stop listening after the stop signal processed self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) if elapsed_time // 10 > last_ping_time: @@ -88,9 +86,7 @@ def publish_error(self, e, pub_from: PublishFrom) -> None: :param pub_from: publish from :return: """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) + self.publish(QueueErrorEvent(error=e), pub_from) def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -122,8 +118,8 @@ def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> N if result is None: return - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": + user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + if result.decode("utf-8") != f"{user_prefix}-{user_id}": return stopped_cache_key = cls._generate_stopped_cache_key(task_id) @@ -168,9 +164,11 @@ def _check_for_sqlalchemy_models(self, data: Any): for item in data: self._check_for_sqlalchemy_models(item) else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") + if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): + raise TypeError( + "Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed." + ) class GenerateTaskStoppedException(Exception): diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 60216959a85b21..aadb43ad399bbc 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -31,12 +31,15 @@ class AppRunner: - def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None) -> int: + def get_pre_calculate_rest_tokens( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + ) -> int: """ Get pre calculate rest tokens :param app_record: app record @@ -49,18 +52,20 @@ def get_pre_calculate_rest_tokens(self, app_record: App, """ # Invoke model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -75,36 +80,39 @@ def get_pre_calculate_rest_tokens(self, app_record: App, prompt_template_entity=prompt_template_entity, inputs=inputs, files=files, - query=query + query=query, ) - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) rest_tokens = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: - raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size.") + raise InvokeBadRequestError( + "Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger token limit size." + ) return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage]): + def recalc_llm_max_tokens( + self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] + ): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -112,27 +120,28 @@ def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, if max_tokens is None: max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: max_tokens = max(model_context_tokens - prompt_tokens, 16) for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): model_config.parameters[parameter_rule.name] = max_tokens - def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def organize_prompt_messages( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + context: Optional[str] = None, + memory: Optional[TokenBufferMemory] = None, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -152,60 +161,54 @@ def organize_prompt_messages(self, app_record: App, app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template - prompt_template = CompletionModelPromptTemplate( - text=advanced_completion_prompt_template.prompt - ) + prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: memory_config.role_prefix = MemoryConfig.RolePrefix( user=advanced_completion_prompt_template.role_prefix.user, - assistant=advanced_completion_prompt_template.role_prefix.assistant + assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: - prompt_template.append(ChatModelMessage( - text=message.text, - role=message.role - )) + prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def direct_output(self, queue_manager: AppQueueManager, - app_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list, - text: str, - stream: bool, - usage: Optional[LLMUsage] = None) -> None: + def direct_output( + self, + queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, + prompt_messages: list, + text: str, + stream: bool, + usage: Optional[LLMUsage] = None, + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -222,17 +225,10 @@ def direct_output(self, queue_manager: AppQueueManager, chunk = LLMResultChunk( model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=token) - ) + delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)), ) - queue_manager.publish( - QueueLLMChunkEvent( - chunk=chunk - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) @@ -242,15 +238,19 @@ def direct_output(self, queue_manager: AppQueueManager, model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() + usage=usage if usage else LLMUsage.empty_usage(), ), - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: AppQueueManager, - stream: bool, - agent: bool = False) -> None: + def _handle_invoke_result( + self, + invoke_result: Union[LLMResult, Generator], + queue_manager: AppQueueManager, + stream: bool, + agent: bool = False, + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -260,21 +260,13 @@ def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], :return: """ if not stream: - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) else: - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_direct( + self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -285,12 +277,13 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager.publish( QueueMessageEndEvent( llm_result=invoke_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_stream( + self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -300,21 +293,13 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, """ model = None prompt_messages = [] - text = '' + text = "" usage = None for result in invoke_result: if not agent: - queue_manager.publish( - QueueLLMChunkEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) else: - queue_manager.publish( - QueueAgentMessageEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -331,25 +316,24 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage = LLMUsage.empty_usage() llm_result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) queue_manager.publish( QueueMessageEndEvent( llm_result=llm_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) def moderation_for_inputs( - self, app_id: str, - tenant_id: str, - app_generate_entity: AppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str, + self, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -367,14 +351,17 @@ def moderation_for_inputs( tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query if query else '', + query=query if query else "", message_id=message_id, - trace_manager=app_generate_entity.trace_manager + trace_manager=app_generate_entity.trace_manager, ) - def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - prompt_messages: list[PromptMessage]) -> bool: + def check_hosting_moderation( + self, + application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, + prompt_messages: list[PromptMessage], + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -384,8 +371,7 @@ def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGe """ hosting_moderation_feature = HostingModerationFeature() moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages + application_generate_entity=application_generate_entity, prompt_messages=prompt_messages ) if moderation_result: @@ -393,18 +379,20 @@ def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGe queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream + text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream, ) return moderation_result - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fill_in_inputs_from_external_data_tools( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -417,18 +405,12 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, """ external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query + tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query ) - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -440,9 +422,5 @@ def query_app_annotations_to_reply(self, app_record: App, """ annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from ) diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index a286c349b2715b..96dc7dda79af6d 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig): """ Chatbot App Config Entity. """ + pass class ChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> ChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> ChatAppConfig: """ Convert app model config to chat app config :param app_model: app model @@ -51,7 +55,7 @@ def get_app_config(cls, app_model: App, config_dict = app_model_config_dict.copy() else: if not override_config_dict: - raise Exception('override_config_dict is required when config_from is ARGS') + raise Exception("override_config_dict is required when config_from is ARGS") config_dict = override_config_dict @@ -63,19 +67,11 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -113,8 +109,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -123,7 +120,8 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -139,8 +137,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index ab15928b744b61..15c7140308d673 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -30,7 +30,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -39,7 +40,8 @@ def generate( @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -47,7 +49,8 @@ def generate( ) -> dict: ... def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -62,58 +65,46 @@ def generate( :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -122,7 +113,7 @@ def generate( app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -141,14 +132,11 @@ def generate( stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -157,17 +145,20 @@ def generate( invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -181,16 +172,16 @@ def generate( stream=stream, ) - return ChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -212,20 +203,19 @@ def _generate_worker(self, flask_app: Flask, application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 89a498eb3607f9..bd90586825d03b 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -58,7 +61,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -66,13 +69,10 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -84,7 +84,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -96,7 +96,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -104,7 +104,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -115,13 +115,13 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -129,7 +129,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -141,7 +141,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -152,7 +152,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_retrieval = DatasetRetrieval(application_generate_entity) @@ -181,29 +181,26 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, files=files, query=query, context=context, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -218,7 +215,5 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 625e14c9c39712..0fa7af0a7fa36d 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -23,15 +23,15 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -93,20 +95,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStr sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index a7711983249a33..1193c4b7a43632 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig): """ Completion App Config Entity. """ + pass class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model @@ -51,19 +52,11 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -101,8 +94,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -114,8 +108,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c0b13b40fdb0d6..d7301224e8a2f5 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -32,7 +32,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -41,19 +42,17 @@ def generate( @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, ) -> dict: ... - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[str, None, None]]: + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -63,12 +62,12 @@ def generate(self, app_model: App, :param invoke_from: invoke from source :param stream: is stream """ - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] extras = {} @@ -76,41 +75,31 @@ def generate(self, app_model: App, conversation = None # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # get tracing instance @@ -128,14 +117,11 @@ def generate(self, app_model: App, stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -144,16 +130,19 @@ def generate(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -167,15 +156,15 @@ def generate(self, app_model: App, stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -194,20 +183,19 @@ def _generate_worker(self, flask_app: Flask, runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - message=message + message=message, ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -216,12 +204,14 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def generate_more_like_this(self, app_model: App, - message_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[str, None, None]]: + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -231,13 +221,17 @@ def generate_more_like_this(self, app_model: App, :param invoke_from: invoke from source :param stream: is stream """ - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -250,29 +244,23 @@ def generate_more_like_this(self, app_model: App, app_model_config = message.app_model_config override_model_config_dict = app_model_config.to_dict() - model_dict = override_model_config_dict['model'] - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - override_model_config_dict['model'] = model_dict + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - message.files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # init application generate entity @@ -286,14 +274,11 @@ def generate_more_like_this(self, app_model: App, user_id=user.id, stream=stream, invoke_from=invoke_from, - extras={} + extras={}, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -302,16 +287,19 @@ def generate_more_like_this(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -325,7 +313,4 @@ def generate_more_like_this(self, app_model: App, stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index f0e5f9ae173c39..da49c8701f3f48 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message) -> None: + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -54,7 +54,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # organize all inputs and template to prompt messages @@ -65,7 +65,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # moderation @@ -77,7 +77,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -85,7 +85,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -97,7 +97,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -108,7 +108,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_config = app_config.dataset @@ -126,7 +126,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, - message_id=message.id + message_id=message.id, ) # reorganize all inputs and template to prompt messages @@ -139,29 +139,26 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, inputs=inputs, files=files, query=query, - context=context + context=context, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -176,8 +173,5 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 14db74dbd04b95..697f0273a5673e 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -23,14 +23,14 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,14 +44,15 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -62,13 +63,13 @@ def convert_stream_full_response(cls, stream_response: Generator[CompletionAppSt sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -79,8 +80,9 @@ def convert_stream_full_response(cls, stream_response: Generator[CompletionAppSt yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -91,19 +93,19 @@ def convert_stream_simple_response(cls, stream_response: Generator[CompletionApp sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index fceed95b91e636..a91d48d246c5fe 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -35,23 +35,23 @@ class MessageBasedAppGenerator(BaseAppGenerator): - def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Handle response. @@ -70,7 +70,7 @@ def _handle_response( conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: @@ -82,12 +82,13 @@ def _handle_response( logger.exception(e) raise e - def _get_conversation_by_user(self, app_model: App, conversation_id: str, - user: Union[Account, EndUser]) -> Conversation: + def _get_conversation_by_user( + self, app_model: App, conversation_id: str, user: Union[Account, EndUser] + ) -> Conversation: conversation_filter = [ Conversation.id == conversation_id, Conversation.app_id == app_model.id, - Conversation.status == 'normal' + Conversation.status == "normal", ] if isinstance(user, Account): @@ -100,19 +101,18 @@ def _get_conversation_by_user(self, app_model: App, conversation_id: str, if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() return conversation - def _get_app_model_config(self, app_model: App, - conversation: Optional[Conversation] = None) \ - -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .first() + ) if not app_model_config: raise AppModelConfigBrokenError() @@ -127,15 +127,16 @@ def _get_app_model_config(self, app_model: App, return app_model_config - def _init_generate_records(self, - application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - conversation: Optional[Conversation] = None) \ - -> tuple[Conversation, Message]: + def _init_generate_records( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + conversation: Optional[Conversation] = None, + ) -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity @@ -148,10 +149,10 @@ def _init_generate_records(self, end_user_id = None account_id = None if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' + from_source = "api" end_user_id = application_generate_entity.user_id else: - from_source = 'console' + from_source = "console" account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -164,8 +165,11 @@ def _init_generate_records(self, model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.COMPLETION, + ]: override_model_configs = app_config.app_model_config_dict # get conversation introduction @@ -179,12 +183,12 @@ def _init_generate_records(self, model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, - name='New conversation', + name="New conversation", inputs=application_generate_entity.inputs, introduction=introduction, system_instruction="", system_instruction_tokens=0, - status='normal', + status="normal", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, @@ -216,11 +220,11 @@ def _init_generate_records(self, answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, - from_account_id=account_id + from_account_id=account_id, ) db.session.add(message) @@ -232,10 +236,10 @@ def _init_generate_records(self, message_id=message.id, type=file.type.value, transfer_method=file.transfer_method.value, - belongs_to='user', + belongs_to="user", url=file.url, upload_file_id=file.related_id, - created_by_role=('account' if account_id else 'end_user'), + created_by_role=("account" if account_id else "end_user"), created_by=account_id or end_user_id, ) db.session.add(message_file) @@ -269,11 +273,7 @@ def _get_conversation(self, conversation_id: str): :param conversation_id: conversation id :return: conversation """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError() @@ -286,10 +286,6 @@ def _get_message(self, message_id: str) -> Message: :param message_id: message id :return: message """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) + message = db.session.query(Message).filter(Message.id == message_id).first() return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index f4ff44dddac9ef..7f259db6ebb55e 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -12,12 +12,9 @@ class MessageBasedAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + def __init__( + self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str + ) -> None: super().__init__(task_id, user_id, invoke_from) self._conversation_id = str(conversation_id) @@ -30,7 +27,7 @@ def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: @@ -45,17 +42,15 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueAdvancedChatMessageEndEvent): + if isinstance( + event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): raise GenerateTaskStoppedException() - diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 36d3696d601da4..8b98e74b85969b 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): """ Workflow App Config Entity. """ + pass @@ -26,13 +27,9 @@ def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -50,8 +47,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -61,9 +57,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 4347e5277bb320..c6850085779117 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -34,26 +34,28 @@ class WorkflowAppGenerator(BaseAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[True] = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> Generator[str, None, None]: ... @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> dict: ... def generate( @@ -65,7 +67,7 @@ def generate( invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ): """ Generate App response. @@ -79,26 +81,19 @@ def generate( :param call_depth: call depth :param workflow_thread_pool_id: workflow thread pool id """ - inputs = args['inputs'] + inputs = args["inputs"] # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -114,7 +109,7 @@ def generate( stream=stream, invoke_from=invoke_from, call_depth=call_depth, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -125,18 +120,19 @@ def generate( application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) def _generate( - self, *, + self, + *, app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -154,17 +150,20 @@ def _generate( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode + app_mode=app_model.mode, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'context': contextvars.copy_context(), - 'workflow_thread_pool_id': workflow_thread_pool_id - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -177,17 +176,11 @@ def _generate( stream=stream, ) - return WorkflowAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -199,16 +192,13 @@ def single_iteration_generate(self, app_model: App, :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -219,13 +209,10 @@ def single_iteration_generate(self, app_model: App, user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -235,14 +222,17 @@ def single_iteration_generate(self, app_model: App, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream + stream=stream, ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - context: contextvars.Context, - workflow_thread_pool_id: Optional[str] = None) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -259,7 +249,7 @@ def _generate_worker(self, flask_app: Flask, runner = WorkflowAppRunner( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) runner.run() @@ -267,14 +257,13 @@ def _generate_worker(self, flask_app: Flask, pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -283,14 +272,14 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool = False) -> Union[ - WorkflowAppBlockingResponse, - Generator[WorkflowAppStreamResponse, None, None] - ]: + def _handle_response( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -306,7 +295,7 @@ def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntit workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index f448138b53c0c2..c9f501cd5e44ec 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -12,10 +12,7 @@ class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode @@ -27,19 +24,18 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: :param pub_from: :return: """ - message = WorkflowQueueMessage( - task_id=self._task_id, - app_mode=self._app_mode, - event=event - ) + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent, + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9d48db7546d092..81c8463dd54a7b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - workflow_thread_pool_id: Optional[str] = None + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, ) -> None: """ :param application_generate_entity: application generate entity @@ -62,16 +62,16 @@ def run(self) -> None: app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) # if only single iteration run is requested @@ -80,10 +80,9 @@ def run(self) -> None: graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: - inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files @@ -120,12 +119,10 @@ def run(self) -> None: invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, - thread_pool_id=self.workflow_thread_pool_id + thread_pool_id=self.workflow_thread_pool_id, ) - generator = workflow_entry.run( - callbacks=workflow_callbacks - ) + generator = workflow_entry.run(callbacks=workflow_callbacks) for event in generator: self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 88bde58ba049ba..08d00ee1805aa2 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -35,8 +35,9 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking return cls.convert_blocking_full_response(blocking_response) @classmethod - def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -47,12 +48,12 @@ def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStre sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -63,8 +64,9 @@ def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStre yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -75,12 +77,12 @@ def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppSt sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 00b3b9f57e9024..904b6493811c07 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -92,7 +96,7 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id + SystemVariableKey.USER_ID: user_id, } self._task_state = WorkflowTaskState() @@ -106,16 +110,13 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr db.session.refresh(self._user) db.session.close() - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> WorkflowAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: """ To blocking response. :return: @@ -137,18 +138,19 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at) - ) + finished_at=int(stream_response.data.finished_at), + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[WorkflowAppStreamResponse, None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: """ To stream response. :return: @@ -158,10 +160,7 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) if isinstance(stream_response, WorkflowStartStreamResponse): workflow_run_id = stream_response.workflow_run_id - yield WorkflowAppStreamResponse( - workflow_run_id=workflow_run_id, - stream_response=stream_response - ) + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) def _listenAudioMsg(self, publisher, task_id: str): if not publisher: @@ -171,17 +170,20 @@ def _listenAudioMsg(self, publisher, task_id: str): return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: @@ -210,13 +212,12 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -241,22 +242,18 @@ def _process_stream_response( # init workflow run workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -267,7 +264,7 @@ def _process_stream_response( response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -278,69 +275,61 @@ def _process_stream_response( response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + outputs=json.dumps(event.outputs) + if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs + else None, conversation_id=None, trace_manager=trace_manager, ) @@ -349,22 +338,23 @@ def _process_stream_response( self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, @@ -374,8 +364,7 @@ def _process_stream_response( self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -394,7 +383,6 @@ def _process_stream_response( if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. @@ -417,7 +405,7 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" workflow_app_log.created_by = self._user.id db.session.add(workflow_app_log) @@ -431,8 +419,7 @@ def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: :return: """ response = TextChunkStreamResponse( - task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text) + task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text) ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17097268876a67..ce266116a7dfaa 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -58,89 +58,86 @@ def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ Init graph """ - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # init graph - graph = Graph.init( - graph_config=graph_config - ) + graph = Graph.init(graph_config=graph_config) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + return graph def _get_graph_and_variable_pool_of_single_iteration( - self, - workflow: Workflow, - node_id: str, - user_inputs: dict, - ) -> tuple[Graph, VariablePool]: + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: """ Get variable pool of single iteration """ # fetch workflow graph graph_config = workflow.graph_dict if not graph_config: - raise ValueError('workflow graph not found') - + raise ValueError("workflow graph not found") + graph_config = cast(dict[str, Any], graph_config) - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # filter nodes only in iteration node_configs = [ - node for node in graph_config.get('nodes', []) - if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + node + for node in graph_config.get("nodes", []) + if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id ] - graph_config['nodes'] = node_configs + graph_config["nodes"] = node_configs - node_ids = [node.get('id') for node in node_configs] + node_ids = [node.get("id") for node in node_configs] # filter edges only in iteration edge_configs = [ - edge for edge in graph_config.get('edges', []) - if (edge.get('source') is None or edge.get('source') in node_ids) - and (edge.get('target') is None or edge.get('target') in node_ids) + edge + for edge in graph_config.get("edges", []) + if (edge.get("source") is None or edge.get("source") in node_ids) + and (edge.get("target") is None or edge.get("target") in node_ids) ] - graph_config['edges'] = edge_configs + graph_config["edges"] = edge_configs # init graph - graph = Graph.init( - graph_config=graph_config, - root_node_id=node_id - ) + graph = Graph.init(graph_config=graph_config, root_node_id=node_id) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + # fetch node config from node id iteration_node_config = None for node in node_configs: - if node.get('id') == node_id: + if node.get("id") == node_id: iteration_node_config = node break if not iteration_node_config: - raise ValueError('iteration node id not found in workflow graph') - + raise ValueError("iteration node id not found in workflow graph") + # Get node class - node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) @@ -153,8 +150,7 @@ def _get_graph_and_variable_pool_of_single_iteration( try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=iteration_node_config + graph_config=workflow.graph_dict, config=iteration_node_config ) except NotImplementedError: variable_mapping = {} @@ -165,7 +161,7 @@ def _get_graph_and_variable_pool_of_single_iteration( variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=IterationNodeData(**iteration_node_config.get('data', {})) + node_data=IterationNodeData(**iteration_node_config.get("data", {})), ) return graph, variable_pool @@ -178,18 +174,12 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) """ if isinstance(event, GraphRunStartedEvent): self._publish_event( - QueueWorkflowStartedEvent( - graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state - ) + QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) ) elif isinstance(event, GraphRunSucceededEvent): - self._publish_event( - QueueWorkflowSucceededEvent(outputs=event.outputs) - ) + self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunFailedEvent): - self._publish_event( - QueueWorkflowFailedEvent(error=event.error) - ) + self._publish_event(QueueWorkflowFailedEvent(error=event.error)) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -204,7 +194,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) start_at=event.route_node_state.start_at, node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -220,14 +210,18 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result else {}, - in_iteration_id=event.in_iteration_id + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunFailedEvent): @@ -243,16 +237,18 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result - and event.route_node_state.node_run_result.error + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunStreamChunkEvent): @@ -260,14 +256,13 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) QueueTextChunkEvent( text=event.chunk_content, from_variable_selector=event.from_variable_selector, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): self._publish_event( QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources, - in_iteration_id=event.in_iteration_id + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id ) ) elif isinstance(event, ParallelBranchRunStartedEvent): @@ -277,7 +272,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunSucceededEvent): @@ -287,7 +282,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunFailedEvent): @@ -298,7 +293,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id, - error=event.error + error=event.error, ) ) elif isinstance(event, IterationRunStartedEvent): @@ -316,7 +311,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, predecessor_node_id=event.predecessor_node_id, - metadata=event.metadata + metadata=event.metadata, ) ) elif isinstance(event, IterationRunNextEvent): @@ -352,7 +347,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) outputs=event.outputs, metadata=event.metadata, steps=event.steps, - error=event.error if isinstance(event, IterationRunFailedEvent) else None + error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) @@ -371,9 +366,6 @@ def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: # return workflow return workflow - + def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 4e8f3644b1bcb2..cdd21bf7c2e94d 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -30,169 +30,145 @@ class WorkflowLoggingCallback(WorkflowCallback): - def __init__(self) -> None: self.current_node_id = None - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): - self.print_text("\n[GraphRunStartedEvent]", color='pink') + self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): - self.print_text("\n[GraphRunSucceededEvent]", color='green') + self.print_text("\n[GraphRunSucceededEvent]", color="green") elif isinstance(event, GraphRunFailedEvent): - self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): - self.on_workflow_node_execute_started( - event=event - ) + self.on_workflow_node_execute_started(event=event) elif isinstance(event, NodeRunSucceededEvent): - self.on_workflow_node_execute_succeeded( - event=event - ) + self.on_workflow_node_execute_succeeded(event=event) elif isinstance(event, NodeRunFailedEvent): - self.on_workflow_node_execute_failed( - event=event - ) + self.on_workflow_node_execute_failed(event=event) elif isinstance(event, NodeRunStreamChunkEvent): - self.on_node_text_chunk( - event=event - ) + self.on_node_text_chunk(event=event) elif isinstance(event, ParallelBranchRunStartedEvent): - self.on_workflow_parallel_started( - event=event - ) + self.on_workflow_parallel_started(event=event) elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): - self.on_workflow_parallel_completed( - event=event - ) + self.on_workflow_parallel_completed(event=event) elif isinstance(event, IterationRunStartedEvent): - self.on_workflow_iteration_started( - event=event - ) + self.on_workflow_iteration_started(event=event) elif isinstance(event, IterationRunNextEvent): - self.on_workflow_iteration_next( - event=event - ) + self.on_workflow_iteration_next(event=event) elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): - self.on_workflow_iteration_completed( - event=event - ) + self.on_workflow_iteration_completed(event=event) else: - self.print_text(f"\n[{event.__class__.__name__}]", color='blue') + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") - def on_workflow_node_execute_started( - self, - event: NodeRunStartedEvent - ) -> None: + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: """ Workflow node execute started """ - self.print_text("\n[NodeRunStartedEvent]", color='yellow') - self.print_text(f"Node ID: {event.node_id}", color='yellow') - self.print_text(f"Node Title: {event.node_data.title}", color='yellow') - self.print_text(f"Type: {event.node_type.value}", color='yellow') + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") - def on_workflow_node_execute_succeeded( - self, - event: NodeRunSucceededEvent - ) -> None: + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: """ Workflow node execute succeeded """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunSucceededEvent]", color='green') - self.print_text(f"Node ID: {event.node_id}", color='green') - self.print_text(f"Node Title: {event.node_data.title}", color='green') - self.print_text(f"Type: {event.node_type.value}", color='green') + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='green') + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green" + ) self.print_text( f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='green') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='green') + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) self.print_text( f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", - color='green') + color="green", + ) - def on_workflow_node_execute_failed( - self, - event: NodeRunFailedEvent - ) -> None: + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: """ Workflow node execute failed """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunFailedEvent]", color='red') - self.print_text(f"Node ID: {event.node_id}", color='red') - self.print_text(f"Node Title: {event.node_data.title}", color='red') - self.print_text(f"Type: {event.node_type.value}", color='red') + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Error: {node_run_result.error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='red') + self.print_text(f"Error: {node_run_result.error}", color="red") + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red" + ) self.print_text( f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='red') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='red') + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red" + ) - def on_node_text_chunk( - self, - event: NodeRunStreamChunkEvent - ) -> None: + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ Publish text chunk """ route_node_state = event.route_node_state if not self.current_node_id or self.current_node_id != route_node_state.node_id: self.current_node_id = route_node_state.node_id - self.print_text('\n[NodeRunStreamChunkEvent]') + self.print_text("\n[NodeRunStreamChunkEvent]") self.print_text(f"Node ID: {route_node_state.node_id}") node_run_result = route_node_state.node_run_result if node_run_result: self.print_text( - f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) self.print_text(event.chunk_content, color="pink", end="") - def on_workflow_parallel_started( - self, - event: ParallelBranchRunStartedEvent - ) -> None: + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: """ Publish parallel started """ - self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') - self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') - self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") if event.in_iteration_id: - self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") def on_workflow_parallel_completed( - self, - event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent ) -> None: """ Publish parallel completed """ if isinstance(event, ParallelBranchRunSucceededEvent): - color = 'blue' + color = "blue" elif isinstance(event, ParallelBranchRunFailedEvent): - color = 'red' - - self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) + color = "red" + + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) self.print_text(f"Parallel ID: {event.parallel_id}", color=color) self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) if event.in_iteration_id: @@ -201,43 +177,37 @@ def on_workflow_parallel_completed( if isinstance(event, ParallelBranchRunFailedEvent): self.print_text(f"Error: {event.error}", color=color) - def on_workflow_iteration_started( - self, - event: IterationRunStartedEvent - ) -> None: + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: """ Publish iteration started """ - self.print_text("\n[IterationRunStartedEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") - def on_workflow_iteration_next( - self, - event: IterationRunNextEvent - ) -> None: + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: """ Publish iteration next """ - self.print_text("\n[IterationRunNextEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') - self.print_text(f"Iteration Index: {event.index}", color='blue') + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") - def on_workflow_iteration_completed( - self, - event: IterationRunSucceededEvent | IterationRunFailedEvent - ) -> None: + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: """ Publish iteration completed """ - self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') - self.print_text(f"Node ID: {event.iteration_id}", color='blue') + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") - def print_text( - self, text: str, color: Optional[str] = None, end: str = "\n" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text - print(f'{text_to_print}', end=end) + print(f"{text_to_print}", end=end) def _get_colored_text(self, text: str, color: str) -> str: """Get colored text.""" diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6a1ab230416d0c..ab8d4e374e26f8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -15,13 +15,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str) -> "InvokeFrom": """ Get value of given mode. @@ -31,7 +32,7 @@ def value_of(cls, value: str) -> 'InvokeFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -40,21 +41,22 @@ def to_source(self) -> str: :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel): """ App Generate Entity. """ + task_id: str # app config @@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ Chat Application Generate Entity. """ + # app config app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity @@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Completion Application Generate Entity. """ + pass @@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Advanced Chat Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -147,15 +155,18 @@ class SingleIterationRunEntity(BaseModel): """ Single Iteration Run Entity. """ + node_id: str inputs: dict single_iteration_run: Optional[SingleIterationRunEntity] = None + class WorkflowAppGenerateEntity(AppGenerateEntity): """ Workflow Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -163,6 +174,7 @@ class SingleIterationRunEntity(BaseModel): """ Single Iteration Run Entity. """ + node_id: str inputs: dict diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 4c86b7eee13e3e..4577e28535f023 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -14,6 +14,7 @@ class QueueEvent(str, Enum): """ QueueEvent enum """ + LLM_CHUNK = "llm_chunk" TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" @@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel): """ QueueEvent abstract entity """ + event: QueueEvent @@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent): QueueLLMChunkEvent entity Only for basic mode apps """ + event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + class QueueIterationStartEvent(AppQueueEvent): """ QueueIterationStartEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_START node_execution_id: str node_id: str @@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent): predecessor_node_id: Optional[str] = None metadata: Optional[dict[str, Any]] = None + class QueueIterationNextEvent(AppQueueEvent): """ QueueIterationNextEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_NEXT index: int @@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel start node id if node is in parallel""" node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Optional[Any] = None # output for the current iteration - @field_validator('output', mode='before') + @field_validator("output", mode="before") @classmethod def set_output(cls, v): """ @@ -113,12 +120,14 @@ def set_output(cls, v): return None if isinstance(v, int | float | str | bool | dict | list): return v - raise ValueError('output must be a valid type') + raise ValueError("output must be a valid type") + class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_COMPLETED node_execution_id: str @@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent): parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" start_at: datetime - + node_run_index: int inputs: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None @@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity """ + event: QueueEvent = QueueEvent.TEXT_CHUNK text: str from_variable_selector: Optional[list[str]] = None @@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event: QueueEvent = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str @@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ QueueRetrieverResourcesEvent entity """ + event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] in_iteration_id: Optional[str] = None @@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent): """ QueueAnnotationReplyEvent entity """ + event: QueueEvent = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ QueueMessageEndEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_END llm_result: Optional[LLMResult] = None @@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent): """ QueueAdvancedChatMessageEndEvent entity """ + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END @@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """ QueueWorkflowStartedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_STARTED graph_runtime_state: GraphRuntimeState @@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED outputs: Optional[dict[str, Any]] = None @@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent): """ QueueWorkflowFailedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str @@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent): """ QueueNodeStartedEvent entity """ + event: QueueEvent = QueueEvent.NODE_STARTED node_execution_id: str @@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ QueueNodeSucceededEvent entity """ + event: QueueEvent = QueueEvent.NODE_SUCCEEDED node_execution_id: str @@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity """ + event: QueueEvent = QueueEvent.NODE_FAILED node_execution_id: str @@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.AGENT_THOUGHT agent_thought_id: str @@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_FILE message_file_id: str @@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity """ + event: QueueEvent = QueueEvent.ERROR error: Any = None @@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent): """ QueuePingEvent entity """ + event: QueueEvent = QueueEvent.PING @@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent): """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -365,19 +393,20 @@ def get_stop_reason(self) -> str: To stop reason """ reason_mapping = { - QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', - QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', - QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', - QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", + QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", } - return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") class QueueMessage(BaseModel): """ QueueMessage abstract entity """ + task_id: str app_mode: str event: AppQueueEvent @@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage): """ MessageQueueMessage entity """ + message_id: str conversation_id: str @@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage): """ WorkflowQueueMessage entity """ + pass @@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): """ QueueParallelBranchRunStartedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED parallel_id: str @@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): """ QueueParallelBranchRunSucceededEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED parallel_id: str @@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): """ QueueParallelBranchRunFailedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED parallel_id: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7cab6ca4e0d4bd..0135c97172b486 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -12,6 +12,7 @@ class TaskState(BaseModel): """ TaskState entity """ + metadata: dict = {} @@ -19,6 +20,7 @@ class EasyUITaskState(TaskState): """ EasyUITaskState entity """ + llm_result: LLMResult @@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState): """ WorkflowTaskState entity """ + answer: str = "" @@ -33,6 +36,7 @@ class StreamEvent(Enum): """ Stream event """ + PING = "ping" ERROR = "error" MESSAGE = "message" @@ -60,6 +64,7 @@ class StreamResponse(BaseModel): """ StreamResponse entity """ + event: StreamEvent task_id: str @@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse): """ ErrorStreamResponse entity """ + event: StreamEvent = StreamEvent.ERROR err: Exception model_config = ConfigDict(arbitrary_types_allowed=True) @@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE id: str answer: str @@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE audio: str @@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END audio: str @@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} @@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse): """ MessageFileStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_FILE id: str type: str @@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse): """ MessageReplaceStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str @@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse): """ AgentThoughtStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int @@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse): """ AgentMessageStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE id: str answer: str @@ -162,6 +176,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -182,6 +197,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -210,6 +226,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -249,7 +266,7 @@ def to_ignore_detail_dict(self): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } @@ -262,6 +279,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -315,9 +333,9 @@ def to_ignore_detail_dict(self): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } - + class ParallelBranchStartStreamResponse(StreamResponse): """ @@ -328,6 +346,7 @@ class Data(BaseModel): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -349,6 +368,7 @@ class Data(BaseModel): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -372,6 +392,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -397,6 +418,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -422,6 +444,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -454,6 +477,7 @@ class Data(BaseModel): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_CHUNK @@ -469,6 +493,7 @@ class Data(BaseModel): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_REPLACE @@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse): """ PingStreamResponse entity """ + event: StreamEvent = StreamEvent.PING @@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel): """ AppStreamResponse entity """ + stream_response: StreamResponse @@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse): """ ChatbotAppStreamResponse entity """ + conversation_id: str message_id: str created_at: int @@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse): """ CompletionAppStreamResponse entity """ + message_id: str created_at: int @@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ + workflow_run_id: Optional[str] = None @@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel): """ AppBlockingResponse entity """ + task_id: str def to_dict(self) -> dict: @@ -532,6 +563,7 @@ class Data(BaseModel): """ Data entity """ + id: str mode: str conversation_id: str @@ -552,6 +584,7 @@ class Data(BaseModel): """ Data entity """ + id: str mode: str message_id: str @@ -571,6 +604,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str status: str diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 19ff94de5e8d58..2e37a126c3eb31 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -13,11 +13,9 @@ class AnnotationReplyFeature: - def query(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -27,8 +25,9 @@ def query(self, app_record: App, :param invoke_from: invoke from :return: """ - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_record.id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + ) if not annotation_setting: return None @@ -41,55 +40,50 @@ def query(self, app_record: App, embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" ) dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) documents = vector.search_by_vector( - query=query, - top_k=1, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) if documents: - annotation_id = documents[0].metadata['annotation_id'] - score = documents[0].metadata['score'] + annotation_id = documents[0].metadata["annotation_id"] + score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: - from_source = 'api' + from_source = "api" else: - from_source = 'console' + from_source = "console" # insert annotation history - AppAnnotationService.add_annotation_history(annotation.id, - app_record.id, - annotation.question, - annotation.content, - query, - user_id, - message.id, - from_source, - score) + AppAnnotationService.add_annotation_history( + annotation.id, + app_record.id, + annotation.question, + annotation.content, + query, + user_id, + message.id, + from_source, + score, + ) return annotation except Exception as e: - logger.warning(f'Query annotation failed, exception: {str(e)}.') + logger.warning(f"Query annotation failed, exception: {str(e)}.") return None return None diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index b8f3e0e1f65b74..ba14b61201e72f 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -8,8 +8,9 @@ class HostingModerationFeature: - def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list[PromptMessage]) -> bool: + def check( + self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage] + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -23,9 +24,6 @@ def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, if isinstance(prompt_message.content, str): text += prompt_message.content + "\n" - moderation_result = moderation.check_moderation( - model_config, - text - ) + moderation_result = moderation.check_moderation(model_config, text) return moderation_result diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f11e8021f0b1cc..227182f5ab0923 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict = {} - def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): + def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -27,13 +27,13 @@ def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): def __init__(self, client_id: str, max_active_requests: int): self.max_active_requests = max_active_requests - if hasattr(self, 'initialized'): + if hasattr(self, "initialized"): return self.initialized = True self.client_id = client_id self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) - self.last_recalculate_time = float('-inf') + self.last_recalculate_time = float("-inf") self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): @@ -46,7 +46,7 @@ def flush_cache(self, use_local_value=False): pipe.execute() else: with redis_client.pipeline() as pipe: - self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) + self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) # flush max active requests (in-transit request list) @@ -54,8 +54,11 @@ def flush_cache(self, use_local_value=False): return request_details = redis_client.hgetall(self.active_requests_key) redis_client.expire(self.active_requests_key, timedelta(days=1)) - timeout_requests = [k for k, v in request_details.items() if - time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] + timeout_requests = [ + k + for k, v in request_details.items() + if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME + ] if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) @@ -69,8 +72,10 @@ def enter(self, request_id: Optional[str] = None) -> str: active_requests_count = redis_client.hlen(self.active_requests_key) if active_requests_count >= self.max_active_requests: - raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " - "concurrent requests allowed is {}.".format(self.max_active_requests)) + raise AppInvokeQuotaExceededError( + "Too many requests. Please try again later. The current maximum " + "concurrent requests allowed is {}.".format(self.max_active_requests) + ) redis_client.hset(self.active_requests_key, request_id, str(time.time())) return request_id @@ -116,5 +121,5 @@ def close(self): if not self.closed: self.closed = True self.rate_limit.exit(self.request_id) - if self.generator is not None and hasattr(self.generator, 'close'): + if self.generator is not None and hasattr(self.generator, "close"): self.generator.close() diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 7de06dfb9639fd..652ef243b44a02 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -25,25 +25,25 @@ ) __all__ = [ - 'IntegerVariable', - 'FloatVariable', - 'ObjectVariable', - 'SecretVariable', - 'StringVariable', - 'ArrayAnyVariable', - 'Variable', - 'SegmentType', - 'SegmentGroup', - 'Segment', - 'NoneSegment', - 'NoneVariable', - 'IntegerSegment', - 'FloatSegment', - 'ObjectSegment', - 'ArrayAnySegment', - 'StringSegment', - 'ArrayStringVariable', - 'ArrayNumberVariable', - 'ArrayObjectVariable', - 'ArraySegment', + "IntegerVariable", + "FloatVariable", + "ObjectVariable", + "SecretVariable", + "StringVariable", + "ArrayAnyVariable", + "Variable", + "SegmentType", + "SegmentGroup", + "Segment", + "NoneSegment", + "NoneVariable", + "IntegerSegment", + "FloatSegment", + "ObjectSegment", + "ArrayAnySegment", + "StringSegment", + "ArrayStringVariable", + "ArrayNumberVariable", + "ArrayObjectVariable", + "ArraySegment", ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index e6e9ce97747ce1..40a69ed4eb31f0 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -28,12 +28,12 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get('value_type')) is None: - raise VariableError('missing value type') - if not mapping.get('name'): - raise VariableError('missing name') - if (value := mapping.get('value')) is None: - raise VariableError('missing value') + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") + if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.NUMBER if isinstance(value, float): result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f'invalid number value {value}') + raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) case _: - raise VariableError(f'not supported value type {value_type}') + raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") return result @@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - raise ValueError(f'not supported value {value}') + raise ValueError(f"not supported value {value}") diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index de6c7966525c06..3c4d7046f496a9 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -4,14 +4,14 @@ from . import SegmentGroup, factory -VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") def convert_template(*, template: str, variable_pool: VariablePool): parts = re.split(VARIABLE_PATTERN, template) segments = [] for part in filter(lambda x: x, parts): - if '.' in part and (value := variable_pool.get(part.split('.'))): + if "." in part and (value := variable_pool.get(part.split("."))): segments.append(value) else: segments.append(factory.build_segment(part)) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index b4ff09b6d39ac9..b363255b2cae9e 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -8,15 +8,15 @@ class SegmentGroup(Segment): @property def text(self): - return ''.join([segment.text for segment in self.value]) + return "".join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.value]) + return "".join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.value]) + return "".join([segment.markdown for segment in self.value]) def to_object(self): return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 5c713cac6747f9..b71924b2d3f870 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -14,13 +14,13 @@ class Segment(BaseModel): value_type: SegmentType value: Any - @field_validator('value_type') + @field_validator("value_type") def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. If the value is different, a ValueError is raised. """ - if value != cls.model_fields['value_type'].default: + if value != cls.model_fields["value_type"].default: raise ValueError("Cannot modify 'value_type'") return value @@ -50,15 +50,15 @@ class NoneSegment(Segment): @property def text(self) -> str: - return 'null' + return "null" @property def log(self) -> str: - return 'null' + return "null" @property def markdown(self) -> str: - return 'null' + return "null" class StringSegment(Segment): @@ -76,24 +76,21 @@ class IntegerSegment(Segment): value: int - - - class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT value: Mapping[str, Any] @property def text(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False) + return json.dumps(self.model_dump()["value"], ensure_ascii=False) @property def log(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) @property def markdown(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) class ArraySegment(Segment): @@ -101,11 +98,11 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, 'to_markdown'): + if hasattr(item, "to_markdown"): items.append(item.to_markdown()) else: items.append(str(item)) - return '\n'.join(items) + return "\n".join(items) class ArrayAnySegment(ArraySegment): @@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index cdd2b0b4b09191..9cf0856df5d1ad 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -2,14 +2,14 @@ class SegmentType(str, Enum): - NONE = 'none' - NUMBER = 'number' - STRING = 'string' - SECRET = 'secret' - ARRAY_ANY = 'array[any]' - ARRAY_STRING = 'array[string]' - ARRAY_NUMBER = 'array[number]' - ARRAY_OBJECT = 'array[object]' - OBJECT = 'object' + NONE = "none" + NUMBER = "number" + STRING = "string" + SECRET = "secret" + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + OBJECT = "object" - GROUP = 'group' + GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index 8fef707fcf298b..f0e403ab8d2592 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -23,11 +23,11 @@ class Variable(Segment): """ id: str = Field( - default='', + default="", description="Unique identity for variable. It's only used by environment variables now.", ) name: str - description: str = Field(default='', description='Description of the variable.') + description: str = Field(default="", description="Description of the variable.") class StringVariable(StringSegment, Variable): @@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 2f74a180d1266c..49f58af12c9f70 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline: _task_state: TaskState _application_generate_entity: AppGenerateEntity - def __init__(self, application_generate_entity: AppGenerateEntity, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -61,18 +64,18 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non e = event.error if isinstance(e, InvokeAuthorizationError): - err = InvokeAuthorizationError('Incorrect API key provided') + err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError) or isinstance(e, ValueError): err = e else: - err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) if message: refetch_message = db.session.query(Message).filter(Message.id == message.id).first() if refetch_message: err_desc = self._error_to_desc(err) - refetch_message.status = 'error' + refetch_message.status = "error" refetch_message.error = err_desc db.session.commit() @@ -86,12 +89,14 @@ def _error_to_desc(self, e: Exception) -> str: :return: """ if isinstance(e, QuotaExceededError): - return ("Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.") + return ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) - message = getattr(e, 'description', str(e)) + message = getattr(e, "description", str(e)) if not message: - message = 'Internal Server Error, please contact support.' + message = "Internal Server Error, please contact support." return message @@ -101,10 +106,7 @@ def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse: :param e: exception :return: """ - return ErrorStreamResponse( - task_id=self._application_generate_entity.task_id, - err=e - ) + return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) def _ping_stream_response(self) -> PingStreamResponse: """ @@ -125,11 +127,8 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: return OutputModeration( tenant_id=app_config.tenant_id, app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager + rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), + queue_manager=self._queue_manager, ) def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: @@ -143,8 +142,7 @@ def _handle_output_moderation_when_task_finished(self, completion: str) -> Optio self._output_moderation_handler.stop_thread() completion = self._output_moderation_handler.moderation_completion( - completion=completion, - public_event=False + completion=completion, public_event=False ) self._output_moderation_handler = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8d91a507a9e8ee..61e920845c1e2e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: EasyUITaskState - _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ] - - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool) -> None: + _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + + def __init__( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -101,18 +99,18 @@ def __init__(self, application_generate_entity: Union[ model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() + usage=LLMUsage.empty_usage(), ) ) self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Process generate task pipeline. @@ -125,22 +123,18 @@ def process( if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ - ChatbotAppBlockingResponse, - CompletionAppBlockingResponse - ]: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: """ Process blocking response. :return: @@ -149,11 +143,9 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = { - 'usage': jsonable_encoder(self._task_state.llm_result.usage) - } + extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -164,8 +156,8 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: response = ChatbotAppBlockingResponse( @@ -177,18 +169,19 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: """ To stream response. :return: @@ -198,14 +191,14 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) yield CompletionAppStreamResponse( message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) def _listenAudioMsg(self, publisher, task_id: str): @@ -217,15 +210,19 @@ def _listenAudioMsg(self, publisher, task_id: str): return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') - if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): - publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + if ( + text_to_speech_dict + and text_to_speech_dict.get("autoPlay") == "enabled" + and text_to_speech_dict.get("enabled") + ): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: audio_response = self._listenAudioMsg(publisher, task_id) @@ -250,14 +247,11 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan break else: start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio.audio, - task_id=task_id) - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -333,9 +327,7 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None - ) -> None: + def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -347,31 +339,32 @@ def _save_message( self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages + self._model_config.mode, self._task_state.llm_result.prompt_messages ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' + self._message.answer = ( + PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + if llm_result.message.content + else "" + ) self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, - conversation_id=self._conversation.id, - message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id ) ) @@ -379,11 +372,9 @@ def _save_message( self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -395,22 +386,17 @@ def _handle_stop(self, event: QueueStopEvent) -> None: model = model_config.model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_instance.get_llm_num_tokens( - self._task_state.llm_result.prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_instance.get_llm_num_tokens( - [self._task_state.llm_result.message] - ) + completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message]) credentials = model_config.credentials @@ -418,10 +404,7 @@ def _handle_stop(self, event: QueueStopEvent) -> None: model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens + model, credentials, prompt_tokens, completion_tokens ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -429,16 +412,14 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: Message end to stream response. :return: """ - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -449,9 +430,7 @@ def _agent_message_to_stream_response(self, answer: str, message_id: str) -> Age :return: """ return AgentMessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer + task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: @@ -461,9 +440,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :return: """ agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) db.session.close() @@ -478,7 +455,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op tool=agent_thought.tool, tool_labels=agent_thought.tool_labels, tool_input=agent_thought.tool_input, - message_files=agent_thought.files + message_files=agent_thought.files, ) return None @@ -500,15 +477,15 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), ) - ), PublishFrom.TASK_PIPELINE + ), + PublishFrom.TASK_PIPELINE, ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 8ff50dd1747c0a..011daba6878ffa 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -30,10 +30,7 @@ class MessageCycleManage: _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] _task_state: Union[EasyUITaskState, WorkflowTaskState] @@ -49,15 +46,18 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) if auto_generate_conversation_name and is_first_message: # start generate thread - thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'conversation_id': conversation.id, - 'query': query - }) + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation.id, + "query": query, + }, + ) thread.start() @@ -65,17 +65,10 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> return None - def _generate_conversation_name_worker(self, - flask_app: Flask, - conversation_id: str, - query: str): + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: return @@ -105,12 +98,9 @@ def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } + self._task_state.metadata["annotation_reply"] = { + "id": annotation.id, + "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, } return annotation @@ -124,7 +114,7 @@ def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> No :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata['retriever_resources'] = event.retriever_resources + self._task_state.metadata["retriever_resources"] = event.retriever_resources def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ @@ -132,27 +122,23 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti :param event: event :return: """ - message_file = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) + message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() if message_file: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url to local file - if message_file.url.startswith('http'): + if message_file.url.startswith("http"): url = message_file.url else: url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) @@ -161,8 +147,8 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or 'user', - url=url + belongs_to=message_file.belongs_to or "user", + url=url, ) return None @@ -174,11 +160,7 @@ def _message_to_stream_response(self, answer: str, message_id: str) -> MessageSt :param message_id: message id :return: """ - return MessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer - ) + return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: """ @@ -186,7 +168,4 @@ def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStre :param answer: answer :return: """ - return MessageReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - answer=answer - ) + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ed3225310a8ec4..a030d5dcbf3c05 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -70,14 +70,14 @@ def _handle_workflow_run_start(self) -> WorkflowRun: inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): - if key.value == 'conversation': + if key.value == "conversation": continue - inputs[f'sys.{key.value}'] = value + inputs[f"sys.{key.value}"] = value inputs = WorkflowEntry.handle_special_values(inputs) - triggered_from= ( + triggered_from = ( WorkflowRunTriggeredFrom.DEBUGGING if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN @@ -185,20 +185,26 @@ def _handle_workflow_run_failed( db.session.commit() - running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value - ).all() + running_workflow_node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + .all() + ) for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + workflow_node_execution.elapsed_time = ( + workflow_node_execution.finished_at - workflow_node_execution.created_at + ).total_seconds() db.session.commit() db.session.refresh(workflow_run) @@ -216,7 +222,9 @@ def _handle_workflow_run_failed( return workflow_run - def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + def _handle_node_execution_start( + self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> WorkflowNodeExecution: # init workflow node execution workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.tenant_id = workflow_run.tenant_id @@ -333,16 +341,16 @@ def _workflow_finish_to_stream_response( created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - 'id': created_by_account.id, - 'name': created_by_account.name, - 'email': created_by_account.email, + "id": created_by_account.id, + "name": created_by_account.name, + "email": created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - 'id': created_by_end_user.id, - 'user': created_by_end_user.session_id, + "id": created_by_end_user.id, + "user": created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -401,7 +409,7 @@ def _workflow_node_start_to_stream_response( # extras logic if event.node_type == NodeType.TOOL: node_data = cast(ToolNodeData, event.node_data) - response.data.extras['icon'] = ToolManager.get_tool_icon( + response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, provider_id=node_data.provider_id, @@ -410,10 +418,10 @@ def _workflow_node_start_to_stream_response( return response def _workflow_node_finish_to_stream_response( - self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. @@ -424,7 +432,7 @@ def _workflow_node_finish_to_stream_response( """ if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: return None - + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -452,13 +460,10 @@ def _workflow_node_finish_to_stream_response( iteration_id=event.in_iteration_id, ), ) - + def _workflow_parallel_branch_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunStartedEvent - ) -> ParallelBranchStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: """ Workflow parallel branch start to stream response :param task_id: task id @@ -476,15 +481,15 @@ def _workflow_parallel_branch_start_to_stream_response( parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, created_at=int(time.time()), - ) + ), ) - + def _workflow_parallel_branch_finished_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent - ) -> ParallelBranchFinishedStreamResponse: + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: """ Workflow parallel branch finished to stream response :param task_id: task id @@ -501,18 +506,15 @@ def _workflow_parallel_branch_finished_to_stream_response( parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, - status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, created_at=int(time.time()), - ) + ), ) def _workflow_iteration_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueIterationStartEvent - ) -> IterationNodeStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: """ Workflow iteration start to stream response :param task_id: task id @@ -534,10 +536,12 @@ def _workflow_iteration_start_to_stream_response( metadata=event.metadata or {}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + def _workflow_iteration_next_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: """ Workflow iteration next to stream response :param task_id: task id @@ -559,10 +563,12 @@ def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + def _workflow_iteration_completed_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: """ Workflow iteration completed to stream response :param task_id: task id @@ -585,13 +591,13 @@ def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflo status=WorkflowNodeExecutionStatus.SUCCEEDED, error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), - total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: @@ -643,7 +649,7 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() @@ -656,11 +662,10 @@ def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() if not workflow_run: - raise Exception(f'Workflow run not found: {workflow_run_id}') + raise Exception(f"Workflow run not found: {workflow_run_id}") return workflow_run @@ -683,6 +688,6 @@ def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNo ) if not workflow_node_execution: - raise Exception(f'Workflow node execution not found: {node_execution_id}') + raise Exception(f"Workflow node execution not found: {node_execution_id}") - return workflow_node_execution \ No newline at end of file + return workflow_node_execution diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 578996574739a8..99e992fd89063e 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -16,31 +16,32 @@ "red": "31;1", } + def get_colored_text(text: str, color: str) -> str: """Get colored text.""" color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) if file: file.flush() # ensure all printed content are written to file + class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = '' + + color: Optional[str] = "" current_loop: int = 1 def __init__(self, color: Optional[str] = None) -> None: super().__init__() """Initialize callback handler.""" # use a specific color is not specified - self.color = color or 'green' + self.color = color or "green" self.current_loop = 1 def on_tool_start( @@ -58,7 +59,7 @@ def on_tool_end( tool_outputs: Sequence[ToolInvokeMessage], message_id: Optional[str] = None, timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> None: """If not the final action, print out observation.""" print_text("\n[on_tool_end]\n", color=self.color) @@ -79,26 +80,21 @@ def on_tool_end( ) ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start( - self, thought: str - ) -> None: + def on_agent_start(self, thought: str) -> None: """Run on agent start.""" if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ - "\nThought: " + thought + "\n", color=self.color) + print_text( + "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n", + color=self.color, + ) else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish( - self, color: Optional[str] = None, **kwargs: Any - ) -> None: + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: """Run on agent end.""" print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) @@ -107,9 +103,9 @@ def on_agent_finish( @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8e1f496b226c14..50cde18c54f7e9 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ - from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -11,11 +10,9 @@ class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: AppQueueManager, - app_id: str, - message_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__( + self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom + ) -> None: self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id @@ -29,11 +26,12 @@ def on_query(self, query: str, dataset_id: str) -> None: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=self._app_id, - created_by_role=('account' - if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self._user_id + created_by_role=( + "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + ), + created_by=self._user_id, ) db.session.add(dataset_query) @@ -43,18 +41,15 @@ def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() @@ -64,26 +59,25 @@ def return_retriever_resource_info(self, resource: list): for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get('position'), - dataset_id=item.get('dataset_id'), - dataset_name=item.get('dataset_name'), - document_id=item.get('document_id'), - document_name=item.get('document_name'), - data_source_type=item.get('data_source_type'), - segment_id=item.get('segment_id'), - score=item.get('score') if 'score' in item else None, - hit_count=item.get('hit_count') if 'hit_count' else None, - word_count=item.get('word_count') if 'word_count' in item else None, - segment_position=item.get('segment_position') if 'segment_position' in item else None, - index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, - content=item.get('content'), - retriever_from=item.get('retriever_from'), - created_by=self._user_id + position=item.get("position"), + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, ) db.session.add(dataset_retriever_resource) db.session.commit() self._queue_manager.publish( - QueueRetrieverResourcesEvent(retriever_resources=resource), - PublishFrom.APPLICATION_MANAGER + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 84bab7e1a3d22f..8ac12f72f29d6c 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -2,4 +2,4 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" \ No newline at end of file + """Callback Handler that prints to std out.""" diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index b7e0cc0c2b2ae6..4cc793b0d76d96 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -29,9 +29,13 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider).first() + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + ) + .first() + ) if embedding: text_embeddings[i] = embedding.get_embedding() else: @@ -41,17 +45,18 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, - self._model_instance.credentials) - max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) for i in range(0, len(embedding_queue_texts), max_chunks): - batch_texts = embedding_queue_texts[i:i + max_chunks] + batch_texts = embedding_queue_texts[i : i + max_chunks] - embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user) for vector in embedding_result.embeddings: try: @@ -60,16 +65,18 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: except IntegrityError: db.session.rollback() except Exception as e: - logging.exception('Failed transform embedding: ', e) + logging.exception("Failed transform embedding: ", e) cache_embeddings = [] try: for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): text_embeddings[i] = embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider, + ) embedding_cache.set_embedding(embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) @@ -78,7 +85,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: db.session.rollback() except Exception as ex: db.session.rollback() - logger.error('Failed to embed documents: ', ex) + logger.error("Failed to embed documents: ", ex) raise ex return text_embeddings @@ -87,16 +94,13 @@ def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) try: - embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user) embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() @@ -116,6 +120,6 @@ def embed_query(self, text: str) -> list[float]: except IntegrityError: db.session.rollback() except: - logging.exception('Failed to add embedding to redis') + logging.exception("Failed to add embedding to redis") return embedding_results diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 0cdf8670c492c2..656bf4aa724893 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -2,7 +2,7 @@ class PlanningStrategy(Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 370aeee4633550..10bc9f6ed7d12b 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -5,7 +5,7 @@ class PromptMessageFileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel): class ImagePromptMessageFile(PromptMessageFile): class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 22a21ecf9331ea..9ed5528e43b9b8 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -12,6 +12,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -40,7 +42,7 @@ def __init__(self, provider_entity: ProviderEntity) -> None: label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel): """ Model class for model response. """ + status: ModelStatus load_balancing_enabled: bool = False @@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. """ + provider: SimpleModelProviderEntity @@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 778ef2e1ac42ad..4797b69b8596bb 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. """ + tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType @@ -67,9 +68,13 @@ def __init__(self, **data): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any(len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: @@ -83,10 +88,9 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: - raise ValueError(f'Model {model} is disabled.') + raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] @@ -99,10 +103,12 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name + if ( + restrict_model.model_type == model_type + and restrict_model.model == model + and restrict_model.base_model_name + ): + copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: @@ -128,20 +134,21 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. :return: """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -161,7 +168,8 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [], ) def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: @@ -171,17 +179,21 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [] ) if provider_record: @@ -189,9 +201,7 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict # fix origin data if provider_record.encrypted_config: if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } + original_credentials = {"openai_api_key": provider_record.encrypted_config} else: original_credentials = json.loads(provider_record.encrypted_config) else: @@ -207,8 +217,7 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, - credentials=credentials + provider=self.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -239,15 +248,13 @@ def add_or_update_custom_credentials(self, credentials: dict) -> None: provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM.value, encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER ) provider_model_credentials_cache.delete() @@ -260,12 +267,15 @@ def delete_custom_credentials(self) -> None: :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # delete provider if provider_record: @@ -277,13 +287,14 @@ def delete_custom_credentials(self) -> None: provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -305,13 +316,15 @@ def get_custom_model_credentials(self, model_type: ModelType, model: str, obfusc return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [], ) return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -321,24 +334,29 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [] ) if provider_model_record: try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = ( + json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + ) except JSONDecodeError: original_credentials = {} @@ -350,10 +368,7 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in credentials.items(): @@ -388,7 +403,7 @@ def add_or_update_custom_model_credentials(self, model_type: ModelType, model: s model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_model_record) db.session.commit() @@ -396,7 +411,7 @@ def add_or_update_custom_model_credentials(self, model_type: ModelType, model: s provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -409,13 +424,16 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # delete provider model if provider_model_record: @@ -425,7 +443,7 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -437,13 +455,16 @@ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSettin :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = True @@ -455,7 +476,7 @@ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSettin provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=True + enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -469,13 +490,16 @@ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetti :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = False @@ -487,7 +511,7 @@ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetti provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=False + enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -501,13 +525,16 @@ def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optio :param model: model name :return: """ - return db.session.query(ProviderModelSetting) \ + return ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -516,24 +543,30 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov :param model: model name :return: """ - load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + load_balancing_config_count = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).count() + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .count() + ) if load_balancing_config_count <= 1: - raise ValueError('Model load balancing configuration must be more than 1.') + raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = True @@ -545,7 +578,7 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=True + load_balancing_enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -559,13 +592,16 @@ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> Pro :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = False @@ -577,7 +613,7 @@ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> Pro provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=False + load_balancing_enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -617,11 +653,14 @@ def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: return # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + preferred_model_provider = ( + db.session.query(TenantPreferredModelProvider) .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value @@ -629,7 +668,7 @@ def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value + preferred_provider_type=provider_type.value, ) db.session.add(preferred_model_provider) @@ -658,9 +697,7 @@ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: lis :return: """ # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - credential_form_schemas - ) + credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() @@ -670,9 +707,9 @@ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: lis return copy_credentials - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -688,8 +725,9 @@ def get_provider_model(self, model_type: ModelType, return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type @@ -711,15 +749,11 @@ def get_provider_models(self, model_type: Optional[ModelType] = None, if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) if only_active: @@ -728,11 +762,12 @@ def get_provider_models(self, model_type: Optional[ModelType] = None, # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_system_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get system provider models. @@ -760,7 +795,7 @@ def _get_system_provider_models(self, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -783,23 +818,20 @@ def _get_system_provider_models(self, if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL]: + ConfigurateMethod.CUSTOMIZABLE_MODEL + ]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name + copy_credentials["base_model_name"] = restrict_model.base_model_name try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -809,8 +841,10 @@ def _get_system_provider_models(self, continue status = ModelStatus.ACTIVE - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -825,7 +859,7 @@ def _get_system_provider_models(self, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -839,11 +873,12 @@ def _get_system_provider_models(self, return provider_models - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -885,7 +920,7 @@ def _get_custom_provider_models(self, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -895,15 +930,13 @@ def _get_custom_provider_models(self, continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -911,8 +944,10 @@ def _get_custom_provider_models(self, status = ModelStatus.ACTIVE load_balancing_enabled = False - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -931,7 +966,7 @@ def _get_custom_provider_models(self, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + tenant_id: str configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. """ + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel # pydantic configs - model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0d5b0a1b2c6ba6..44725623dc4bd4 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -8,18 +8,19 @@ class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" class SystemConfigurationStatus(Enum): """ Enum class for system configuration status. """ - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" class RestrictModel(BaseModel): @@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. """ + quota_type: ProviderQuotaType quota_unit: QuotaUnit quota_limit: int @@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel): """ Model class for provider system configuration. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. """ + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] @@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel): """ Class for model load balancing configuration. """ + id: str name: str credentials: dict @@ -93,6 +100,7 @@ class ModelSettings(BaseModel): """ Model class for model settings. """ + model: str model_type: ModelType enabled: bool = True diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 53323a2eeb67ce..3b186476ebe977 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -3,6 +3,7 @@ class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ def __init__(self, description: Optional[str] = None) -> None: class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception): """ Custom exception raised when the quota for an app has been exceeded. """ + description = "App Invoke Quota Exceeded" @@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" class InvokeRateLimitError(Exception): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4db7a999736c5b..38cebb6b6b1c36 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -20,10 +20,7 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: :param params: the request params :return: the response json """ - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer {}".format(self.api_key) - } + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} url = self.api_endpoint @@ -32,20 +29,17 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: proxies = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxies = { - 'http': dify_config.SSRF_PROXY_HTTP_URL, - 'https': dify_config.SSRF_PROXY_HTTPS_URL, + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, } response = requests.request( - method='POST', + method="POST", url=url, - json={ - 'point': point.value, - 'params': params - }, + json={"point": point.value, "params": params}, headers=headers, timeout=self.timeout, - proxies=proxies + proxies=proxies, ) except requests.exceptions.Timeout: raise ValueError("request timeout") @@ -53,9 +47,8 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format( - response.status_code, - response.text[:100] - )) + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 8d73aa2b8bab8c..f1a49c4921906e 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -11,8 +11,8 @@ class ExtensionModule(enum.Enum): - MODERATION = 'moderation' - EXTERNAL_DATA_TOOL = 'external_data_tool' + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" class ModuleExtension(BaseModel): @@ -41,12 +41,12 @@ def scan_extensions(cls): position_map = {} # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") current_dir_path = os.path.dirname(current_path) # traverse subdirectories for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith('__'): + if subdir_name.startswith("__"): continue subdir_path = os.path.join(current_dir_path, subdir_name) @@ -58,21 +58,21 @@ def scan_extensions(cls): # in the front-end page and business logic, there are special treatments. builtin = False position = None - if '__builtin__' in file_names: + if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, '__builtin__') + builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): - with open(builtin_file_path, encoding='utf-8') as f: + with open(builtin_file_path, encoding="utf-8") as f: position = int(f.read().strip()) position_map[extension_name] = position - if (extension_name + '.py') not in file_names: + if (extension_name + ".py") not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + '.py') + py_path = os.path.join(subdir_path, extension_name + ".py") spec = importlib.util.spec_from_file_location(extension_name, py_path) if not spec or not spec.loader: raise Exception(f"Failed to load module {extension_name} from {py_path}") @@ -91,25 +91,29 @@ def scan_extensions(cls): json_data = {} if not builtin: - if 'schema.json' not in file_names: + if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, 'schema.json') + json_path = os.path.join(subdir_path, "schema.json") json_data = {} if os.path.exists(json_path): - with open(json_path, encoding='utf-8') as f: + with open(json_path, encoding="utf-8") as f: json_data = json.load(f) - extensions.append(ModuleExtension( - extension_class=extension_class, - name=extension_name, - label=json_data.get('label'), - form_schema=json_data.get('form_schema'), - builtin=builtin, - position=position - )) - - sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) + extensions.append( + ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get("label"), + form_schema=json_data.get("form_schema"), + builtin=builtin, + position=position, + ) + ) + + sorted_extensions = sort_to_dict_by_position_map( + position_map=position_map, data=extensions, name_func=lambda x: x.name + ) return sorted_extensions diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 29e892c58ac550..3da170455e3398 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -6,10 +6,7 @@ class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} - module_classes = { - ExtensionModule.MODERATION: Moderation, - ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool - } + module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool} def init(self): for module, module_class in self.module_classes.items(): diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 58c82502ea4447..54ec97a4933a94 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -30,10 +30,11 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: raise ValueError("api_based_extension_id is required") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -50,47 +51,42 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: api_based_extension_id = self.config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == self.tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(self.variable)) + raise ValueError( + "[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid".format(self.variable) + ) # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=self.tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key) try: # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( - self.variable, - e - )) - - response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ - 'app_id': self.app_id, - 'tool_variable': self.variable, - 'inputs': inputs, - 'query': query - }) - - if 'result' not in response_json: - raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" - .format(self.variable)) - - if not isinstance(response_json['result'], str): - raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" - .format(self.variable)) - - return response_json['result'] + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) + + response_json = requestor.request( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query}, + ) + + if "result" not in response_json: + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result not found in response".format( + self.variable + ) + ) + + if not isinstance(response_json["result"], str): + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + ) + + return response_json["result"] diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 8601cb34e79582..84b94e117ff5f9 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -12,11 +12,14 @@ class ExternalDataFetch: - def fetch(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fetch( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -38,7 +41,7 @@ def fetch(self, tenant_id: str, app_id, tool, inputs, - query + query, ) futures[future] = tool @@ -50,12 +53,15 @@ def fetch(self, tenant_id: str, inputs.update(results) return inputs - def _query_external_data_tool(self, flask_app: Flask, - tenant_id: str, - app_id: str, - external_data_tool: ExternalDataVariableEntity, - inputs: dict, - query: str) -> tuple[Optional[str], Optional[str]]: + def _query_external_data_tool( + self, + flask_app: Flask, + tenant_id: str, + app_id: str, + external_data_tool: ExternalDataVariableEntity, + inputs: dict, + query: str, + ) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app @@ -72,17 +78,10 @@ def _query_external_data_tool(self, flask_app: Flask, tool_config = external_data_tool.config external_data_tool_factory = ExternalDataToolFactory( - name=tool_type, - tenant_id=tenant_id, - app_id=app_id, - variable=tool_variable, - config=tool_config + name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config ) # query external data tool - result = external_data_tool_factory.query( - inputs=inputs, - query=query - ) + result = external_data_tool_factory.query(inputs=inputs, query=query) return tool_variable, result diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 979f243af65f61..28721098594962 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -5,14 +5,10 @@ class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( - tenant_id=tenant_id, - app_id=app_id, - variable=variable, - config=config + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 3959f4b4a0bb61..5c4e694025ea73 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel): """ File Upload Entity. """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -28,9 +29,9 @@ def value_of(value): class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" @staticmethod def value_of(value): @@ -39,9 +40,10 @@ def value_of(value): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" @staticmethod def value_of(value): @@ -65,16 +67,16 @@ class FileVar(BaseModel): def to_dict(self) -> dict: return { - '__variant': self.__class__.__name__, - 'tenant_id': self.tenant_id, - 'type': self.type.value, - 'transfer_method': self.transfer_method.value, - 'url': self.preview_url, - 'remote_url': self.url, - 'related_id': self.related_id, - 'filename': self.filename, - 'extension': self.extension, - 'mime_type': self.mime_type, + "__variant": self.__class__.__name__, + "tenant_id": self.tenant_id, + "type": self.type.value, + "transfer_method": self.transfer_method.value, + "url": self.preview_url, + "remote_url": self.url, + "related_id": self.related_id, + "filename": self.filename, + "extension": self.extension, + "mime_type": self.mime_type, } def to_markdown(self) -> str: @@ -86,7 +88,7 @@ def to_markdown(self) -> str: if self.type == FileType.IMAGE: text = f'![{self.filename or ""}]({preview_url})' else: - text = f'[{self.filename or preview_url}]({preview_url})' + text = f"[{self.filename or preview_url}]({preview_url})" return text @@ -115,28 +117,29 @@ def prompt_message_content(self) -> ImagePromptMessageContent: return ImagePromptMessageContent( data=self.data, detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW + if image_config.get("detail") == "high" + else ImagePromptMessageContent.DETAIL.LOW, ) def _get_data(self, force_url: bool = False) -> Optional[str]: from models.model import UploadFile + if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.related_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url + upload_file = ( + db.session.query(UploadFile) + .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) + .first() ) + + return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) elif self.transfer_method == FileTransferMethod.TOOL_FILE: extension = self.extension # add sign url - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=extension + ) return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 085ff07cfde921..8feaabedbbdf45 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -13,13 +13,13 @@ class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, - user: Union[Account, EndUser]) -> list[FileVar]: + def validate_and_transform_files_arg( + self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] + ) -> list[FileVar]: """ validate and transform files arg @@ -30,22 +30,22 @@ def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], f """ for file in files: if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): - raise ValueError('Missing file tool_file_id') + raise ValueError("Invalid file format, must be dict") + if not file.get("type"): + raise ValueError("Missing file type") + FileType.value_of(file.get("type")) + if not file.get("transfer_method"): + raise ValueError("Missing file transfer method") + FileTransferMethod.value_of(file.get("transfer_method")) + if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: + if not file.get("url"): + raise ValueError("Missing file url") + if not file.get("url").startswith("http"): + raise ValueError("Invalid file url") + if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): + raise ValueError("Missing file upload_file_id") + if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): + raise ValueError("Missing file tool_file_id") # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -62,17 +62,17 @@ def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], f continue # Validate number of files - if len(files) > image_config['number_limits']: + if len(files) > image_config["number_limits"]: raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") for file_obj in file_objs: # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') + if file_obj.transfer_method.value not in image_config["transfer_methods"]: + raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") # Validate file type if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') + raise ValueError(f"Invalid file type: {file_obj.type}") if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: # check remote url valid and is image @@ -81,18 +81,21 @@ def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], f raise ValueError(error) elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) + upload_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.related_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + UploadFile.extension.in_(IMAGE_EXTENSIONS), + ) + .first() + ) # check upload file is belong to tenant and user if not upload_file: - raise ValueError('Invalid upload file') + raise ValueError("Invalid upload file") new_files.append(file_obj) @@ -113,8 +116,9 @@ def transform_message_files(self, files: list[MessageFile], file_extra_config: F # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: + def _to_file_objs( + self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig + ) -> dict[FileType, list[FileVar]]: """ transform files to file objs @@ -152,23 +156,23 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileEx :return: """ if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) if transfer_method != FileTransferMethod.TOOL_FILE: return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config + url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config, ) return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, url=None, - related_id=file.get('tool_file_id'), - extra_config=file_extra_config + related_id=file.get("tool_file_id"), + extra_config=file_extra_config, ) else: return FileVar( @@ -178,7 +182,7 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileEx transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, related_id=file.upload_file_id or None, - extra_config=file_extra_config + extra_config=file_extra_config, ) def _check_image_remote_url(self, url): @@ -190,17 +194,17 @@ def _check_image_remote_url(self, url): def is_s3_presigned_url(url): try: parsed_url = urlparse(url) - if 'amazonaws.com' not in parsed_url.netloc: + if "amazonaws.com" not in parsed_url.netloc: return False query_params = parse_qs(parsed_url.query) - required_params = ['Signature', 'Expires'] + required_params = ["Signature", "Expires"] for param in required_params: if param not in query_params: return False - if not query_params['Expires'][0].isdigit(): + if not query_params["Expires"][0].isdigit(): return False - signature = query_params['Signature'][0] - if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): + signature = query_params["Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): return False return True except Exception: diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac577e3a..1efaf5529db83f 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,8 +1,7 @@ -tool_file_manager = { - 'manager': None -} +tool_file_manager = {"manager": None} + class ToolFileParser: @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + def get_tool_file_manager() -> "ToolFileManager": + return tool_file_manager["manager"] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 737a11e426c745..a8c1fd4d02d01e 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -9,7 +9,7 @@ from configs import dify_config from extensions.ext_storage import storage -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) @@ -22,18 +22,18 @@ def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: if upload_file.extension not in IMAGE_EXTENSIONS: return None - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 try: data = storage.load(upload_file.key) except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') + logging.error(f"File not found: {upload_file.key}") return None - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' + encoded_string = base64.b64encode(data).decode("utf-8") + return f"data:{upload_file.mime_type};base64,{encoded_string}" @classmethod def get_signed_temp_image_url(cls, upload_file_id) -> str: @@ -44,7 +44,7 @@ def get_signed_temp_image_url(cls, upload_file_id) -> str: :return: """ base_url = dify_config.FILES_URL - image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' + image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4662ebb47a56a5..4a80a3ffe92207 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -15,9 +15,11 @@ logger = logging.getLogger(__name__) + class CodeExecutionException(Exception): pass + class CodeExecutionResponse(BaseModel): class Data(BaseModel): stdout: Optional[str] = None @@ -29,9 +31,9 @@ class Data(BaseModel): class CodeLanguage(str, Enum): - PYTHON3 = 'python3' - JINJA2 = 'jinja2' - JAVASCRIPT = 'javascript' + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" class CodeExecutor: @@ -45,63 +47,65 @@ class CodeExecutor: } code_language_to_running_language = { - CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = { - CodeLanguage.PYTHON3 - } + supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod - def execute_code(cls, - language: CodeLanguage, - preload: str, - code: str) -> str: + def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" - headers = { - 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY - } + headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} data = { - 'language': cls.code_language_to_running_language.get(language), - 'code': code, - 'preload': preload, - 'enable_network': True + "language": cls.code_language_to_running_language.get(language), + "code": code, + "preload": preload, + "enable_network": True, } try: - response = post(str(url), json=data, headers=headers, - timeout=Timeout( - connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, - read=dify_config.CODE_EXECUTION_READ_TIMEOUT, - write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, - pool=None)) + response = post( + str(url), + json=data, + headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ), + ) if response.status_code == 503: - raise CodeExecutionException('Code execution service is unavailable') + raise CodeExecutionException("Code execution service is unavailable") elif response.status_code != 200: - raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') + raise Exception( + f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" + ) except CodeExecutionException as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, which is likely a network issue,' - ' please check if the sandbox service is running.' - f' ( Error: {str(e)} )') + raise CodeExecutionException( + "Failed to execute code, which is likely a network issue," + " please check if the sandbox service is running." + f" ( Error: {str(e)} )" + ) try: response = response.json() except: - raise CodeExecutionException('Failed to parse response') + raise CodeExecutionException("Failed to parse response") - if (code := response.get('code')) != 0: + if (code := response.get("code")) != 0: raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) @@ -109,7 +113,7 @@ def execute_code(cls, if response.data.error: raise CodeExecutionException(response.data.error) - return response.data.stdout or '' + return response.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: @@ -122,7 +126,7 @@ def execute_workflow_code_template(cls, language: CodeLanguage, code: str, input """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f'Unsupported language {language}') + raise CodeExecutionException(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 3f099b7ac5bbb4..e233a596b9da0e 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -26,23 +26,9 @@ def get_default_config(cls) -> dict: return { "type": "code", "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], "code_language": cls.get_language(), "code": cls.get_default_code(), - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } + "outputs": {"result": {"type": "string", "children": None}}, + }, } diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index a157fcc6d147cd..ae324b83a95124 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -18,4 +18,5 @@ def get_default_code(cls) -> str: result: arg1 + arg2 } } - """) + """ + ) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index a4d2551972e3a4..d67a0903aa4d4c 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -21,5 +21,6 @@ def get_runner_script(cls) -> str: var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` console.log(result) - """) + """ + ) return runner_script diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index f1e5da584c660d..db2eb5ebb6b19a 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -10,8 +10,6 @@ def format(cls, template: str, inputs: dict) -> str: :param inputs: inputs :return: """ - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=inputs - ) + result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - return result['result'] + return result["result"] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index b8cb29600e106e..63d58edbc794e9 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -11,9 +11,7 @@ def transform_response(cls, response: str) -> dict: :param response: response :return: """ - return { - 'result': cls.extract_result_str_from_response(response) - } + return {"result": cls.extract_result_str_from_response(response)} @classmethod def get_runner_script(cls) -> str: diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 923724b49d8a2d..9cca8af7c698bc 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -17,4 +17,5 @@ def main(arg1: str, arg2: str) -> dict: return { "result": arg1 + arg2, } - """) + """ + ) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf66558b658f57..6f016f27bc874d 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,9 +5,9 @@ class TemplateTransformer(ABC): - _code_placeholder: str = '{{code}}' - _inputs_placeholder: str = '{{inputs}}' - _result_tag: str = '<>' + _code_placeholder: str = "{{code}}" + _inputs_placeholder: str = "{{inputs}}" + _result_tag: str = "<>" @classmethod def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: @@ -24,9 +24,9 @@ def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: @classmethod def extract_result_str_from_response(cls, response: str) -> str: - result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError('Failed to parse result') + raise ValueError("Failed to parse result") result = result.group(1) return result @@ -50,7 +50,7 @@ def get_runner_script(cls) -> str: @classmethod def serialize_inputs(cls, inputs: dict) -> str: inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() - input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded @classmethod @@ -67,4 +67,4 @@ def get_preload_script(cls) -> str: """ Get preload script """ - return '' + return "" diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 5e5deb86b47e54..96341a1b780a80 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -8,14 +8,15 @@ def obfuscated_token(token: str): if not token: return token if len(token) <= 8: - return '*' * 20 - return token[:6] + '*' * 12 + token[-2:] + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): - raise ValueError(f'Tenant with id {tenant_id} not found') + raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 29cb4acc7d03c2..5e274f8916869d 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -25,7 +25,7 @@ def get(self) -> Optional[dict]: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 20feae8554f79d..b880590de28476 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -12,19 +12,20 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config - if (moderation_config and moderation_config.enabled is True - and 'openai' in hosting_configuration.provider_map - and hosting_configuration.provider_map['openai'].enabled is True + if ( + moderation_config + and moderation_config.enabled is True + and "openai" in hosting_configuration.provider_map + and hosting_configuration.provider_map["openai"].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider - if using_provider_type == ProviderType.SYSTEM \ - and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map['openai'] + if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: + hosting_openai_config = hosting_configuration.provider_map["openai"] # 2000 text per chunk length = 2000 - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] if len(text_chunks) == 0: return True @@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() moderation_result = model_type_instance.invoke( - model='text-moderation-stable', - credentials=hosting_openai_config.credentials, - text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: return True except Exception as ex: logger.exception(ex) - raise InvokeBadRequestError('Rate limit exceeded, please try again later.') + raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 2000577a406e6f..e6e149154870da 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] """ Get all the subclasses of the parent type from the module """ - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] return classes @@ -56,6 +57,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 32e3806231858b..3efdc8aa471697 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -73,10 +73,10 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) def is_filtered( - include_set: set[str], - exclude_set: set[str], - data: Any, - name_func: Callable[[Any], str], + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], ) -> bool: """ Check if the object should be filtered out. @@ -102,9 +102,9 @@ def is_filtered( def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map. @@ -117,13 +117,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into a ordered dict by the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 14ca8e943c71fc..4e6d58904e90eb 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,31 +1,34 @@ """ Proxy requests to avoid SSRF """ + import logging import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') -SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') -SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) +SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") +SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") +SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = { - 'http://': SSRF_PROXY_HTTP_URL, - 'https://': SSRF_PROXY_HTTPS_URL -} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +proxies = ( + {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} + if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL + else None +) BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - + retries = 0 while retries <= max_retries: try: @@ -52,24 +55,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('GET', url, max_retries=max_retries, **kwargs) + return make_request("GET", url, max_retries=max_retries, **kwargs) def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('POST', url, max_retries=max_retries, **kwargs) + return make_request("POST", url, max_retries=max_retries, **kwargs) def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PUT', url, max_retries=max_retries, **kwargs) + return make_request("PUT", url, max_retries=max_retries, **kwargs) def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PATCH', url, max_retries=max_retries, **kwargs) + return make_request("PATCH", url, max_retries=max_retries, **kwargs) def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('DELETE', url, max_retries=max_retries, **kwargs) + return make_request("DELETE", url, max_retries=max_retries, **kwargs) def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('HEAD', url, max_retries=max_retries, **kwargs) + return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index a6f486e81de006..4c3b736186383d 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,14 +9,11 @@ class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType, - identity_id: str - ): + def __init__( + self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str + ): self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" def get(self) -> Optional[dict]: @@ -28,7 +25,7 @@ def get(self) -> Optional[dict]: cached_tool_parameter = redis_client.get(self.cache_key) if cached_tool_parameter: try: - cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = cached_tool_parameter.decode("utf-8") cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None @@ -52,4 +49,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6880c..94b02cf98578b1 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -9,6 +9,7 @@ class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -22,7 +23,7 @@ def get(self) -> Optional[dict]: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None @@ -46,4 +47,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ddcd7512861da9..eeeccc234917ff 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -46,7 +46,7 @@ class HostingConfiguration: def init_app(self, app: Flask) -> None: config = app.config - if config.get('EDITION') != 'CLOUD': + if config.get("EDITION") != "CLOUD": return self.provider_map["azure_openai"] = self.init_azure_openai(config) @@ -65,7 +65,7 @@ def init_azure_openai(app_config: Config) -> HostingProvider: credentials = { "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" + "base_model_name": "gpt-35-turbo", } quotas = [] @@ -77,26 +77,45 @@ def init_azure_openai(app_config: Config) -> HostingProvider: RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM), RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM), RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel( + model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM + ), RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] + RestrictModel( + model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM + ), + RestrictModel( + model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM + ), + RestrictModel( + model="text-embedding-ada-002", + base_model_name="text-embedding-ada-002", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-small", + base_model_name="text-embedding-3-small", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-large", + base_model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], ) quotas.append(trial_quota) - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -110,17 +129,12 @@ def init_openai(self, app_config: Config) -> HostingProvider: if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: @@ -134,12 +148,7 @@ def init_openai(self, app_config: Config) -> HostingProvider: if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -153,9 +162,7 @@ def init_anthropic(app_config: Config) -> HostingProvider: if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): @@ -170,12 +177,7 @@ def init_anthropic(app_config: Config) -> HostingProvider: if app_config.get("HOSTED_ANTHROPIC_API_BASE"): credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -192,7 +194,7 @@ def init_minimax(app_config: Config) -> HostingProvider: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -210,7 +212,7 @@ def init_spark(app_config: Config) -> HostingProvider: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -228,7 +230,7 @@ def init_zhipuai(app_config: Config) -> HostingProvider: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -238,21 +240,19 @@ def init_zhipuai(app_config: Config) -> HostingProvider: @staticmethod def init_moderation_config(app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): + if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( - enabled=True, - providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') + enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") ) - return HostedModerationConfig( - enabled=False - ) + return HostedModerationConfig(enabled=False) @staticmethod def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: models_str = app_config.get(env_var) models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - + return [ + RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) + for model_name in models_list + if model_name.strip() + ] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index df563f609b3022..b6968e46cdc2a9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -39,7 +39,6 @@ class IndexingRunner: - def __init__(self): self.storage = storage self.model_manager = ModelManager() @@ -49,25 +48,26 @@ def run(self, dataset_documents: list[DatasetDocument]): for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) @@ -76,20 +76,20 @@ def run(self, dataset_documents: list[DatasetDocument]): index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, - documents=documents + documents=documents, ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: - logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) + logging.warning("Document deleted, document id: {}".format(dataset_document.id)) except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -98,26 +98,25 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() for document_segment in document_segments: db.session.delete(document_segment) db.session.commit() # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -125,28 +124,26 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) # load self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -155,17 +152,14 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] @@ -180,42 +174,48 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) documents.append(document) # build index # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate( + self, + tenant_id: str, + extract_settings: list[ExtractSetting], + tmp_processing_rule: dict, + doc_form: str = None, + doc_language: str = "English", + dataset_id: str = None, + indexing_technique: str = "economy", + ) -> dict: """ Estimate the indexing for the document. """ @@ -229,18 +229,16 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': + raise ValueError("Dataset not found.") + if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -248,7 +246,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -263,8 +261,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) + mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter @@ -272,9 +269,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin # split to documents documents = self._split_to_documents_for_estimate( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule + text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) @@ -282,110 +277,110 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin if len(preview_texts) < 5: preview_texts.append(document.page_content) - if doc_form and doc_form == 'qa_model': - + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) + response = LLMGenerator.generate_qa_document( + current_user.current_tenant_id, preview_texts[0], doc_language + ) document_qa_list = self.format_split_text(response) - return { - "total_segments": total_segments * 20, - "qa_preview": document_qa_list, - "preview": preview_texts - } - return { - "total_segments": total_segments, - "preview": preview_texts - } + return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} + return {"total_segments": total_segments, "preview": preview_texts} - def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ - -> list[Document]: + def _extract( + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + ) -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == 'upload_file': - if not data_source_info or 'upload_file_id' not in data_source_info: + if dataset_document.data_source_type == "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + ) if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=dataset_document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'notion_import': - if (not data_source_info or 'notion_workspace_id' not in data_source_info - or 'notion_page_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], "document": dataset_document, - "tenant_id": dataset_document.tenant_id + "tenant_id": dataset_document.tenant_id, }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'website_crawl': - if (not data_source_info or 'provider' not in data_source_info - or 'url' not in data_source_info or 'job_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], "tenant_id": dataset_document.tenant_id, - "url": data_source_info['url'], - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata['document_id'] = dataset_document.id - text_doc.metadata['dataset_id'] = dataset_document.dataset_id + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @staticmethod def filter_string(text): - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) return text @staticmethod - def _get_splitter(processing_rule: DatasetProcessRule, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -399,10 +394,10 @@ def _get_splitter(processing_rule: DatasetProcessRule, separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") - if segmentation.get('chunk_overlap'): - chunk_overlap = segmentation['chunk_overlap'] + if segmentation.get("chunk_overlap"): + chunk_overlap = segmentation["chunk_overlap"] else: chunk_overlap = 0 @@ -411,22 +406,27 @@ def _get_splitter(processing_rule: DatasetProcessRule, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter - def _step_split(self, text_docs: list[Document], splitter: TextSplitter, - dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> list[Document]: + def _step_split( + self, + text_docs: list[Document], + splitter: TextSplitter, + dataset: Dataset, + dataset_document: DatasetDocument, + processing_rule: DatasetProcessRule, + ) -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -436,14 +436,12 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language + document_language=dataset_document.doc_language, ) # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -457,7 +455,7 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -465,15 +463,21 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) return documents - def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> list[Document]: + def _split_to_documents( + self, + text_docs: list[Document], + splitter: TextSplitter, + processing_rule: DatasetProcessRule, + tenant_id: str, + document_form: str, + document_language: str, + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -488,12 +492,11 @@ def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, documents = splitter.split_documents([text_doc]) split_documents = [] for document_node in documents: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -506,15 +509,21 @@ def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, split_documents.append(document_node) all_documents.extend(split_documents) # processing qa document - if document_form == 'qa_model': + if document_form == "qa_model": for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, - 'document_language': document_language}) + document_format_thread = threading.Thread( + target=self.format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": tenant_id, + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": document_language, + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -533,12 +542,14 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy()) + qa_document = Document( + page_content=result["question"], metadata=document_node.metadata.model_copy() + ) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -546,8 +557,9 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> list[Document]: + def _split_to_documents_for_estimate( + self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -567,8 +579,8 @@ def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) - document.metadata['doc_id'] = doc_id - document.metadata['doc_hash'] = hash + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -586,23 +598,23 @@ def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} - if 'pre_processing_rules' in rules: + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text @@ -611,27 +623,26 @@ def format_split_text(text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] - def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, - dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + dataset_document: DatasetDocument, + documents: list[Document], + ) -> None: """ insert index and update document/segment status to completed """ embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # chunk nodes by chunk size @@ -640,18 +651,27 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, chunk_size = 10 # create keyword index - create_keyword_thread = threading.Thread(target=self._process_keyword_index, - args=(current_app._get_current_object(), - dataset.id, dataset_document.id, documents)) + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) create_keyword_thread.start() - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance)) + chunk_documents = documents[i : i + chunk_size] + futures.append( + executor.submit( + self._process_chunk, + current_app._get_current_object(), + index_processor, + chunk_documents, + dataset, + dataset_document, + embedding_model_instance, + ) + ) for future in futures: tokens += future.result() @@ -668,7 +688,7 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, - } + }, ) @staticmethod @@ -679,23 +699,26 @@ def _process_keyword_index(flask_app, dataset_id, document_id, documents): raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != 'high_quality': - document_ids = [document.metadata['doc_id'] for document in documents] + if dataset.indexing_technique != "high_quality": + document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() - def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance): + def _process_chunk( + self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + ): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -703,26 +726,26 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d tokens = 0 if embedding_model_instance: tokens += sum( - embedding_model_instance.get_text_embedding_num_tokens( - [document.page_content] - ) + embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) for document in chunk_documents ) # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata['doc_id'] for document in chunk_documents] + document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() @@ -730,14 +753,15 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d @staticmethod def _check_document_paused_status(document_id: str): - indexing_cache_key = 'document_{}_is_paused'.format(document_id) + indexing_cache_key = "document_{}_is_paused".format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() @staticmethod - def _update_document_index_status(document_id: str, after_indexing_status: str, - extra_update_params: Optional[dict] = None) -> None: + def _update_document_index_status( + document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + ) -> None: """ Update the document indexing status. """ @@ -748,9 +772,7 @@ def _update_document_index_status(document_id: str, after_indexing_status: str, if not document: raise DocumentIsDeletedPausedException() - update_params = { - DatasetDocument.indexing_status: after_indexing_status - } + update_params = {DatasetDocument.indexing_status: after_indexing_status} if extra_update_params: update_params.update(extra_update_params) @@ -780,7 +802,7 @@ def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index @@ -788,17 +810,23 @@ def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) - def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, - text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]: + def _transform( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + text_docs: list[Document], + doc_language: str, + process_rule: dict, + ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -806,18 +834,20 @@ def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, - process_rule=process_rule, tenant_id=dataset.tenant_id, - doc_language=doc_language) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=process_rule, + tenant_id=dataset.tenant_id, + doc_language=doc_language, + ) return documents def _load_segments(self, dataset, dataset_document, documents): # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -831,7 +861,7 @@ def _load_segments(self, dataset, dataset_document, documents): extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -839,8 +869,8 @@ def _load_segments(self, dataset, dataset_document, documents): dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c13b4a45cbe6c..78a6d6e6836d39 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -43,21 +43,16 @@ def generate_conversation_name( with measure_time() as timer: response = model_instance.invoke_llm( - prompt_messages=prompts, - model_parameters={ - "max_tokens": 100, - "temperature": 1 - }, - stream=False + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False ) answer = response.message.content - cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) result_dict = json.loads(cleaned_answer) - answer = result_dict['Your Output'] + answer = result_dict["Your Output"] name = answer.strip() if len(name) > 75: - name = name[:75] + '...' + name = name[:75] + "..." # get tracing instance trace_manager = TraceQueueManager(app_id=app_id) @@ -79,14 +74,9 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - template="{{histories}}\n{{format_instructions}}\nquestions:\n" - ) + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({ - "histories": histories, - "format_instructions": format_instructions - }) + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: model_manager = ModelManager() @@ -101,12 +91,7 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - "max_tokens": 256, - "temperature": 0 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False ) questions = output_parser.parse(response.message.content) @@ -119,32 +104,24 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" - rule_config = { - "prompt": "", - "variables": [], - "opening_statement": "", - "error": "" - } - model_parameters = { - "max_tokens": rule_config_max_tokens, - "temperature": 0.01 - } + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} if no_variable: - prompt_template = PromptTemplateParser( - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE - ) + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate)] @@ -158,13 +135,11 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.content - + except InvokeError as e: error = str(e) error_step = "generate rule config" @@ -179,24 +154,18 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - prompt_generate - ) + prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser( - parameter_generate - ) + parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser( - statement_generate - ) + statement_template = PromptTemplateParser(statement_generate) # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] @@ -213,9 +182,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: # the first step to generate the task prompt prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -230,7 +197,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di inputs={ "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] @@ -240,15 +207,13 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di "TASK_DESCRIPTION": instruction, "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False ) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) except InvokeError as e: @@ -257,9 +222,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False ) rule_config["opening_statement"] = statement_content.message.content except InvokeError as e: @@ -284,18 +247,10 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): model_type=ModelType.LLM, ) - prompt_messages = [ - SystemPromptMessage(content=prompt), - UserPromptMessage(content=query) - ] + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - 'temperature': 0.01, - "max_tokens": 2000 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False ) answer = response.message.content diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 8856f0c6856952..b6932698cbf462 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -10,9 +10,12 @@ class RuleConfigGeneratorOutputParser: - def get_format_instructions(self) -> tuple[str, str, str]: - return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) def parse(self, text: str) -> Any: try: @@ -21,16 +24,9 @@ def parse(self, text: str) -> Any: if not isinstance(parsed["prompt"], str): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): - raise ValueError( - "Expected 'variables' to be a list." - ) + raise ValueError("Expected 'variables' to be a list.") if not isinstance(parsed["opening_statement"], str): - raise ValueError( - "Expected 'opening_statement' to be a str." - ) + raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException( - f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" - ) - + raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 3f046c68fceaf0..182aeed98fd7ff 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -6,7 +6,6 @@ class SuggestedQuestionsAfterAnswerOutputParser: - def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -15,7 +14,7 @@ def parse(self, text: str) -> Any: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index dbd6e26c7c260c..7ab257872f54c6 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -66,19 +66,19 @@ "and keeping each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" - "[\"question1\",\"question2\",\"question3\"]\n" + '["question1","question2","question3"]\n' ) GENERATOR_QA_PROMPT = ( - ' The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step.' - 'Step 1: Understand and summarize the main content of this text.\n' - 'Step 2: What key information or concepts are mentioned in this text?\n' - 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' - 'Step 4: Generate questions and answers based on these key information and concepts.\n' - ' The questions should be clear and detailed, and the answers should be detailed and complete. ' - 'You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n' - ' Use the following format: Q1:\nA1:\nQ2:\nA2:...\n' - '' + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" ) WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index b33d4dd7cb342c..54b1d8212bd619 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -21,8 +21,9 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> self.conversation = conversation self.model_instance = model_instance - def get_history_prompt_messages(self, max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> list[PromptMessage]: + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> list[PromptMessage]: """ Get history prompt messages. :param max_token_limit: max token limit @@ -31,16 +32,11 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, app_record = self.conversation.app # fetch limited messages, and return reversed - query = db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id - ).filter( - Message.conversation_id == self.conversation.id, - Message.answer != '' - ).order_by(Message.created_at.desc()) + query = ( + db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id) + .filter(Message.conversation_id == self.conversation.id, Message.answer != "") + .order_by(Message.created_at.desc()) + ) if message_limit and message_limit > 0: message_limit = message_limit if message_limit <= 500 else 500 @@ -50,10 +46,7 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, messages = query.limit(message_limit).all() messages = list(reversed(messages)) - message_file_parser = MessageFileParser( - tenant_id=app_record.tenant_id, - app_id=app_record.id - ) + message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() @@ -63,20 +56,17 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == message.workflow_run_id).first()) + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) if workflow_run: file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, - is_vision=False + workflow_run.workflow.features_dict, is_vision=False ) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] @@ -97,24 +87,23 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, return [] # prune the chat message if it exceeds the max token limit - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and len(prompt_messages)>1: + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> str: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: """ Get history prompt text. :param human_prefix: human prefix @@ -123,10 +112,7 @@ def get_history_prompt_text(self, human_prefix: str = "Human", :param message_limit: message limit :return: """ - prompt_messages = self.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit - ) + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) string_messages = [] for m in prompt_messages: diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index bba004a32a21d6..92da53c9a464df 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -18,12 +18,21 @@ class Callback: Base class for callbacks. Only for LLM. """ + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -39,10 +48,19 @@ def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -59,10 +77,19 @@ def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -79,10 +106,19 @@ def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -99,9 +135,7 @@ def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, cred """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 0406853b88b9c9..3b6b825244dfdc 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -10,11 +10,20 @@ logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +37,49 @@ def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +97,19 @@ def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +124,33 @@ def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') - - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') - - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") + + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") + + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +165,5 @@ def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, cred :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 175c13cfdcc04c..659ad59bd67f91 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index e04d9fcbbbdc4a..e94be6f9181cd1 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,123 +2,123 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', - }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", + }, + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', - }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', - }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", + }, + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", + }, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_K: { - 'label': { - 'en_US': 'Top K', - 'zh_Hans': 'Top K', - }, - 'type': 'int', - 'help': { - 'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.', - 'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。', - }, - 'required': False, - 'default': 50, - 'min': 1, - 'max': 100, - 'precision': 0, + "label": { + "en_US": "Top K", + "zh_Hans": "Top K", + }, + "type": "int", + "help": { + "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", + "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", + }, + "required": False, + "default": 50, + "min": 1, + "max": 100, + "precision": 0, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', - }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", + }, + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', - }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", + }, + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', - }, - 'type': 'int', - 'help': { - 'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.', - 'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。', - }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", + }, + "type": "int", + "help": { + "en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.", + "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", + }, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], + "required": False, + "options": ["JSON", "XML"], }, DefaultParameterName.JSON_SCHEMA: { - 'label': { - 'en_US': 'JSON Schema', + "label": { + "en_US": "JSON Schema", }, - 'type': 'text', - 'help': { - 'en_US': 'Set a response json schema will ensure LLM to adhere it.', - 'zh_Hans': '设置返回的json schema,llm将按照它返回', + "type": "text", + "help": { + "en_US": "Set a response json schema will ensure LLM to adhere it.", + "zh_Hans": "设置返回的json schema,llm将按照它返回", }, - 'required': False, + "required": False, }, } diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 59a4c103a20702..52b590f66a3873 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -12,11 +12,12 @@ class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +27,14 @@ def value_of(cls, value: str) -> 'LLMMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. """ + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,20 +52,20 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) - def plus(self, other: 'LLMUsage') -> 'LLMUsage': + def plus(self, other: "LLMUsage") -> "LLMUsage": """ Add two LLMUsage instances together. @@ -85,10 +87,10 @@ def plus(self, other: 'LLMUsage') -> 'LLMUsage': total_tokens=self.total_tokens + other.total_tokens, total_price=self.total_price + other.total_price, currency=other.currency, - latency=self.latency + other.latency + latency=self.latency + other.latency, ) - def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + def __add__(self, other: "LLMUsage") -> "LLMUsage": """ Overload the + operator to add two LLMUsage instances. @@ -97,10 +99,12 @@ def __add__(self, other: 'LLMUsage') -> 'LLMUsage': """ return self.plus(other) + class LLMResult(BaseModel): """ Model class for llm result. """ + model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -112,6 +116,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. """ + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -122,6 +127,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -132,4 +138,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e8e6963b56d7a7..e51bb18debefe4 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -9,13 +9,14 @@ class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. @@ -25,13 +26,14 @@ def value_of(cls, value: str) -> 'PromptMessageRole': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,16 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT @@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. """ + class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,7 +135,7 @@ class ToolCallFunction(BaseModel): type: str function: ToolCallFunction - @field_validator('id', mode='before') + @field_validator("id", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -145,10 +157,12 @@ def is_empty(self) -> bool: return True + class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. """ + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d6377d7e88a283..d898ef149092f6 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. """ + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,22 @@ def value_of(cls, origin_model_type: str) -> "ModelType": :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +51,28 @@ def to_origin_model_type(self) -> str: :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +81,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -89,6 +93,7 @@ class DefaultParameterName(str, Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" TOP_K = "top_k" @@ -99,7 +104,7 @@ class DefaultParameterName(str, Enum): JSON_SCHEMA = "json_schema" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. @@ -109,13 +114,14 @@ def value_of(cls, value: Any) -> 'DefaultParameterName': for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -127,6 +133,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. """ + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -144,6 +151,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -158,6 +166,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -175,6 +184,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -185,6 +195,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -197,6 +208,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -205,6 +217,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f88f89d5886332..bfe861a97ffbf8 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -12,6 +12,7 @@ class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -20,6 +21,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -39,6 +42,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -46,15 +50,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -110,6 +117,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -138,7 +146,7 @@ def to_simple_provider(self) -> SimpleProviderEntity: icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -146,5 +154,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. """ + provider: str credentials: dict diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/core/model_runtime/entities/rerank_entities.py index d51efd2b3be133..99709e1bcd2127 100644 --- a/api/core/model_runtime/entities/rerank_entities.py +++ b/api/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def3791333..846b89d6580b18 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 0513cfaf67b216..edfb19c7d07d4c 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ def __str__(self): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. " class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 8db79a52bb612a..7fcd2133f9f8d1 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 716bb63566c372..09d2d7e54da9e1 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -66,12 +66,14 @@ def _transform_invoke_error(self, error: Exception) -> InvokeError: :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ") + return invoke_error( + description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. " + ) return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") @@ -115,7 +117,7 @@ def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) return PriceInfo( unit_price=unit_price, @@ -136,24 +138,26 @@ def predefined_models(self) -> list[AIModelEntity]: model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -165,10 +169,10 @@ def predefined_models(self) -> list[AIModelEntity]: yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) + default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"]) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) copy_default_parameter_rule = default_parameter_rule.copy() copy_default_parameter_rule.update(parameter_rule) @@ -176,31 +180,26 @@ def predefined_models(self) -> list[AIModelEntity]: except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] - } + if "label" not in parameter_rule: + parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]} new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] - } + if "label" not in yaml_data: + yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]} - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = AIModelEntity(**yaml_data) except Exception as e: model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}" + ) # cache model schema model_schemas.append(model_schema) @@ -235,7 +234,9 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: Mapping + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -261,19 +262,19 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if not parameter_rule.default and "default" in default_parameter_rule: + parameter_rule.default = default_parameter_rule["default"] + if not parameter_rule.precision and "precision" in default_parameter_rule: + parameter_rule.precision = default_parameter_rule["precision"] + if not parameter_rule.required and "required" in default_parameter_rule: + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + en_US=default_parameter_rule["help"]["en_US"], ) if ( parameter_rule.help diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index cfc8942c794f36..5c39186e6504e6 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. """ + model_type: ModelType = ModelType.LLM # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,7 +77,7 @@ def invoke(self, model: str, credentials: dict, callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -82,7 +90,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: @@ -96,7 +104,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -111,7 +119,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -127,7 +135,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( @@ -140,15 +148,23 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -183,7 +199,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) model_parameters.pop("response_format") @@ -195,15 +211,16 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", str(prompt_messages[0].content)) + content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.") + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last text message @@ -216,9 +233,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message break else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -228,33 +243,30 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) else: return self._code_block_mode_stream_processor( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -303,16 +315,13 @@ def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[Pr prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. @@ -378,18 +387,23 @@ def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_mes prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: """ Invoke result generator @@ -397,9 +411,7 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d :return: result generator """ callbacks = callbacks or [] - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -418,7 +430,7 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -438,7 +450,7 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d prompt_messages=prompt_messages, message=prompt_message, usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -447,15 +459,21 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -472,8 +490,13 @@ def _invoke(self, model: str, credentials: dict, raise NotImplementedError @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -519,7 +542,9 @@ def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> L return mode - def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -539,10 +564,7 @@ def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -558,16 +580,23 @@ def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger before invoke callbacks @@ -593,7 +622,7 @@ def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -601,11 +630,19 @@ def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, else: logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger new chunk callbacks @@ -632,7 +669,7 @@ def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, creden tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -640,11 +677,19 @@ def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, creden else: logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger after invoke callbacks @@ -672,7 +717,7 @@ def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credent tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -680,11 +725,19 @@ def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credent else: logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger invoke error callbacks @@ -712,7 +765,7 @@ def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -758,11 +811,13 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") @@ -775,16 +830,19 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di else: if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places." + ) # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): raise ValueError(f"Model Parameter {parameter_name} should be bool.") diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 780460a3f738fe..4374093de4ab38 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -29,32 +29,32 @@ def validate_provider_credentials(self, credentials: dict) -> None: def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = load_yaml_file(yaml_path) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -92,15 +92,15 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') + model_type_name = model_type.value.replace("-", "_") model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception(f"Invalid model type {model_type} for provider {provider_name}") # Dynamic loading {model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) + parent_module = ".".join(self.__class__.__module__.split(".")[:-1]) mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 2b17f292c5db00..d04414ccb87a63 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -12,14 +12,13 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -37,9 +36,7 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke large language model @@ -50,4 +47,3 @@ def _invoke(self, model: str, credentials: dict, :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 2c86f25180eab8..5fb96047425592 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. """ + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,10 +44,16 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 4fb11025fe07fd..b6b0b737436d9c 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -12,14 +12,13 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. """ + model_type: ModelType = ModelType.SPEECH2TEXT # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -35,9 +34,7 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -59,4 +56,4 @@ def _get_demo_file_path(self) -> str: current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py index e0f1adb1c47f23..a5810e2f0e4b09 100644 --- a/api/core/model_runtime/model_providers/__base/text2img_model.py +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -11,14 +11,15 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. """ + model_type: ModelType = ModelType.TEXT2IMG # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -36,9 +37,9 @@ def invoke(self, model: str, credentials: dict, prompt: str, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 381d2f6cd19ed4..54a44860236918 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. """ + model_type: ModelType = ModelType.TEXT_EMBEDDING # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model @@ -38,9 +39,9 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f5619685..5fe6dda6ad5d79 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,28 @@ _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 2dfd323a47a27f..70be9322a75d53 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -15,13 +15,15 @@ class TTSModel(AIModel): """ Model class for TTS model. """ + model_type: ModelType = ModelType.TTS # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -35,14 +37,21 @@ def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: st :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -71,10 +80,13 @@ def get_tts_model_voices(self, model: str, credentials: dict, language: Optional if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if - language and language in d.get('language')] + return [ + {"name": d["name"], "value": d["mode"]} + for d in voices + if language and language in d.get("language") + ] else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ @@ -123,23 +135,23 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): match = re.compile(pattern) tx = match.finditer(org_text) start = 0 result = [] - one_sentence = '' + one_sentence = "" for i in tx: end = i.regs[0][1] tmp = org_text[start:end] if len(one_sentence + tmp) > max_length: result.append(one_sentence) - one_sentence = '' + one_sentence = "" one_sentence += tmp start = end last_sens = org_text[start:] if last_sens: one_sentence += last_sens - if one_sentence != '': + if one_sentence != "": result.append(one_sentence) return result diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py index 325c6c060ef6c8..5b12f04a3e59b8 100644 --- a/api/core/model_runtime/model_providers/anthropic/anthropic.py +++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `claude-3-opus-20240229` model for validate, - model_instance.validate_credentials( - model='claude-3-opus-20240229', - credentials=credentials - ) + model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 81be1a06a7cd0e..30e9d2e9f2e0d1 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -55,11 +55,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -76,10 +82,17 @@ def _invoke(self, model: str, credentials: dict, # invoke model return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -96,41 +109,39 @@ def _chat_generate(self, model: str, credentials: dict, credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample") # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get('max_tokens') > 4096: + if model_parameters.get("max_tokens") > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" if tools: - extra_model_kwargs['tools'] = [ - self._transform_tool_prompt(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( model=model, messages=prompt_message_dicts, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: # chat model @@ -140,22 +151,30 @@ def _chat_generate(self, model: str, credentials: dict, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if model_parameters.get('response_format'): + if model_parameters.get("response_format"): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -167,24 +186,27 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict: - return { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.parameters - } - - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters} + + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -197,22 +219,30 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -228,9 +258,9 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr tokens = client.count_tokens(prompt) tool_call_inner_prompts_tokens_map = { - 'claude-3-opus-20240229': 395, - 'claude-3-haiku-20240307': 264, - 'claude-3-sonnet-20240229': 159 + "claude-3-opus-20240229": 395, + "claude-3-haiku-20240307": 264, + "claude-3-sonnet-20240229": 159, } if model in tool_call_inner_prompts_tokens_map and tools: @@ -257,13 +287,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Union[Message, ToolsBetaMessage], + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -274,22 +309,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[]) for content in response.content: - if content.type == 'text': + if content.type == "text": assistant_prompt_message.content += content.text - elif content.type == 'tool_use': + elif content.type == "tool_use": tool_call = AssistantPromptMessage.ToolCall( id=content.id, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content.name, - arguments=json.dumps(content.input) - ) + name=content.name, arguments=json.dumps(content.input) + ), ) assistant_prompt_message.tool_calls.append(tool_call) @@ -308,17 +339,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm chat stream response @@ -327,7 +355,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -338,24 +366,23 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, for chunk in response: if isinstance(chunk, MessageStartEvent): - if hasattr(chunk, 'content_block'): + if hasattr(chunk, "content_block"): content_block = chunk.content_block if isinstance(content_block, dict): - if content_block.get('type') == 'tool_use': + if content_block.get("type") == "tool_use": tool_call = AssistantPromptMessage.ToolCall( - id=content_block.get('id'), - type='function', + id=content_block.get("id"), + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content_block.get('name'), - arguments='' - ) + name=content_block.get("name"), arguments="" + ), ) tool_calls.append(tool_call) - elif hasattr(chunk, 'delta'): + elif hasattr(chunk, "delta"): delta = chunk.delta if isinstance(delta, dict) and len(tool_calls) > 0: - if delta.get('type') == 'input_json_delta': - tool_calls[-1].function.arguments += delta.get('partial_json', '') + if delta.get("type") == "input_json_delta": + tool_calls[-1].function.arguments += delta.get("partial_json", "") elif chunk.message: return_model = chunk.message.model input_tokens = chunk.message.usage.input_tokens @@ -369,29 +396,24 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, # transform empty tool call arguments to {} for tool_call in tool_calls: if not tool_call.function.arguments: - tool_call.function.arguments = '{}' + tool_call.function.arguments = "{}" yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -401,7 +423,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -412,14 +434,14 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if credentials.get('anthropic_api_url'): - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if credentials.get("anthropic_api_url"): + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/") + credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs @@ -452,10 +474,7 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -465,25 +484,25 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) @@ -492,34 +511,28 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl content = [] if message.tool_calls: for tool_call in message.tool_calls: - content.append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + content.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + ) if message.content: - content.append({ - "type": "text", - "text": message.content - }) - + content.append({"type": "text", "text": message.content}) + if prompt_message_dicts[-1]["role"] == "assistant": prompt_message_dicts[-1]["content"].extend(content) else: - prompt_message_dicts.append({ - "role": "assistant", - "content": content - }) + prompt_message_dicts.append({"role": "assistant", "content": content}) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [ + {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content} + ], } prompt_message_dicts.append(message_dict) else: @@ -576,16 +589,13 @@ def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) - :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -601,24 +611,14 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError - ], - InvokeAuthorizationError: [ - anthropic.AuthenticationError, - anthropic.PermissionDeniedError - ], + InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], + InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 31c788d226db34..32a0269af49314 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,10 @@ class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +29,14 @@ def _to_credential_kwargs(credentials: dict) -> dict: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index f4f7d964ef38cd..c2744691c3ed6e 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,12 @@ PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,51 +54,47 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -112,37 +109,37 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + _get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-0125', + base_model_name="gpt-35-turbo-0125", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -157,51 +154,47 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -216,32 +209,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -249,34 +239,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -291,32 +277,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -324,34 +307,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-0125-preview', + base_model_name="gpt-4-0125-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -366,32 +345,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -399,34 +375,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -441,32 +413,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -474,34 +443,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini', + base_model_name="gpt-4o-mini", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -517,32 +482,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -550,34 +512,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini-2024-07-18', + base_model_name="gpt-4o-mini-2024-07-18", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -593,32 +551,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -626,46 +581,40 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o', + base_model_name="gpt-4o", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -681,32 +630,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -714,34 +660,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-05-13', + base_model_name="gpt-4o-2024-05-13", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -757,32 +699,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -790,34 +729,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-08-06', + base_model_name="gpt-4o-2024-08-06", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -833,32 +768,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -866,46 +798,40 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo', + base_model_name="gpt-4-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -921,32 +847,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -954,34 +877,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo-2024-04-09', + base_model_name="gpt-4-turbo-2024-04-09", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -997,32 +916,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -1030,39 +946,33 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, - features=[ - ModelFeature.VISION - ], + features=[ModelFeature.VISION], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, @@ -1070,32 +980,29 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -1103,34 +1010,30 @@ class AzureBaseModel(BaseModel): max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-35-turbo-instruct", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1140,19 +1043,19 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1161,16 +1064,16 @@ class AzureBaseModel(BaseModel): input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1180,19 +1083,19 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1201,20 +1104,18 @@ class AzureBaseModel(BaseModel): input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] EMBEDDING_BASE_MODELS = [ AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="text-embedding-ada-002", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1224,17 +1125,15 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1244,17 +1143,15 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1264,135 +1161,129 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } - ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py index 68977b2266718d..2e3c6aab0588ec 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -6,6 +6,5 @@ class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c0c782e42b4a9b..2ad2289869bbae 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -34,16 +34,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - base_model_name = credentials.get('base_model_name') + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -56,7 +60,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -67,7 +71,7 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) def get_num_tokens( @@ -75,14 +79,14 @@ def get_num_tokens( model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: - raise ValueError(f'Base Model Name {base_model_name} is invalid') + raise ValueError(f"Base Model Name {base_model_name} is invalid") model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -92,21 +96,21 @@ def get_num_tokens( # text completion model, do not support tool calling content = prompt_messages[0].content assert isinstance(content, str) - return self._num_tokens_from_string(credentials,content) + return self._num_tokens_from_string(credentials, content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise CredentialsValidateFailedError('Base Model Name is required') + raise CredentialsValidateFailedError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -118,7 +122,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -127,7 +131,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -137,33 +141,35 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -172,15 +178,12 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] ): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -209,24 +212,21 @@ def _handle_generate_response( return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] ) -> Generator: - full_text = '' + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -254,8 +254,8 @@ def _handle_generate_stream_response( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -265,14 +265,20 @@ def _handle_generate_stream_response( delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") @@ -293,7 +299,7 @@ def _chat_generate(self, model: str, credentials: dict, extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] + extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] # extra_model_kwargs['functions'] = [{ # "name": tool.name, # "description": tool.description, @@ -301,10 +307,10 @@ def _chat_generate(self, model: str, credentials: dict, # } for tool in tools] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -322,9 +328,12 @@ def _chat_generate(self, model: str, credentials: dict, return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) def _handle_chat_generate_response( - self, model: str, credentials: dict, response: ChatCompletion, + self, + model: str, + credentials: dict, + response: ChatCompletion, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): assistant_message = response.choices[0].message assistant_message_tool_calls = assistant_message.tool_calls @@ -334,10 +343,7 @@ def _handle_chat_generate_response( self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -369,13 +375,13 @@ def _handle_chat_generate_stream_response( credentials: dict, response: Stream[ChatCompletionChunk], prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): index = 0 - full_assistant_content = '' + full_assistant_content = "" real_model = model system_fingerprint = None - completion = '' + completion = "" tool_calls = [] for chunk in response: if len(chunk.choices) == 0: @@ -386,7 +392,6 @@ def _handle_chat_generate_stream_response( if delta.delta is None: continue - # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) @@ -396,15 +401,14 @@ def _handle_chat_generate_stream_response( # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content if delta.delta.content else "" yield LLMResultChunk( model=real_model, @@ -413,7 +417,7 @@ def _handle_chat_generate_stream_response( delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 0 @@ -421,9 +425,7 @@ def _handle_chat_generate_stream_response( # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - full_assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + full_assistant_prompt_message = AssistantPromptMessage(content=completion) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage @@ -434,27 +436,24 @@ def _handle_chat_generate_stream_response( prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), ) @staticmethod - def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None: + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: if tool_calls_response: for response_tool_call in tool_calls_response: if isinstance(response_tool_call, ChatCompletionMessageToolCall): function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) elif isinstance(response_tool_call, ChoiceDeltaToolCall): @@ -463,8 +462,10 @@ def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_c tool_calls[index].id = response_tool_call.id or tool_calls[index].id tool_calls[index].type = response_tool_call.type or tool_calls[index].type if response_tool_call.function: - tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name - tool_calls[index].function.arguments += response_tool_call.function.arguments or '' + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" else: assert response_tool_call.id is not None assert response_tool_call.type is not None @@ -473,13 +474,10 @@ def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_c assert response_tool_call.function.arguments is not None function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) @@ -495,19 +493,13 @@ def _convert_prompt_message_to_dict(message: PromptMessage): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -525,7 +517,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage): "role": "tool", "name": message.name, "content": message.content, - "tool_call_id": message.tool_call_id + "tool_call_id": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -535,10 +527,11 @@ def _convert_prompt_message_to_dict(message: PromptMessage): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -550,14 +543,13 @@ def _num_tokens_from_string(self, credentials: dict, text: str, return num_tokens def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -591,10 +583,10 @@ def _num_tokens_from_messages( # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -626,40 +618,39 @@ def _num_tokens_from_messages( @staticmethod def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters['title'])) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters['type'])) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters['properties'].items(): + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 8aebcb90e40b6a..a2b14cf3dbe6d4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -40,7 +38,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -65,10 +63,9 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> return response.text def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e073bef0149486..d9cff8ecbbadb9 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -16,19 +16,18 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +43,9 @@ def _invoke(self, model: str, credentials: dict, enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -56,10 +53,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -75,10 +69,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,24 +79,16 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: if len(texts) == 0: return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +101,52 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not self._get_ai_model_entity(credentials['base_model_name'], model): + if not self._get_ai_model_entity(credentials["base_model_name"], model): raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +157,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index f9ddd86f68a9ec..bbad7264673d30 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -30,13 +31,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -50,14 +50,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name @@ -75,23 +74,29 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: st if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -108,10 +113,9 @@ def _process_sentence(self, sentence: str, model: str, return response.read() def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 71bd6b5d923ed1..626fc811cfd47b 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `baichuan2-turbo` model for validate, - model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb60f71c..bea6777f833a49 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,17 @@ class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc. - text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index a8fd9dce91abbf..6e181ac5f8b3db 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -94,7 +94,6 @@ def generate( timeout: int, tools: Optional[list[PromptMessageTool]] = None, ) -> Union[Iterator, dict]: - if model in self._model_mapping.keys(): api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: @@ -120,9 +119,7 @@ def generate( err = resp["error"]["type"] msg = resp["error"]["message"] except Exception as e: - raise InternalServerError( - f"Failed to convert response to json: {e} with text: {response.text}" - ) + raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") if err == "invalid_api_key": raise InvalidAPIKeyError(msg) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a291c06..4e56e58d7eba15 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 36c7003d1b1ec9..3291fe2b2edf22 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -38,17 +38,16 @@ class BaichuanLanguageModel(LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, ) -> LLMResult | Generator: return self._generate( model=model, @@ -60,17 +59,17 @@ def _invoke( ) def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, ) -> int: return self._num_tokens_from_messages(prompt_messages) def _num_tokens_from_messages( - self, - messages: list[PromptMessage], + self, + messages: list[PromptMessage], ) -> int: """Calculate num tokens for baichuan model""" @@ -111,18 +110,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} else: raise ValueError(f"Unknown message type {type(message)}") @@ -146,15 +140,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError(f"Invalid API key: {e}") def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stream: bool = True, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, ) -> LLMResult | Generator: - instance = BaichuanModel(api_key=credentials["api_key"]) messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -169,23 +162,19 @@ def _generate( ) if stream: - return self._handle_chat_generate_stream_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) - return self._handle_chat_generate_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_response(model, prompt_messages, credentials, response) def _handle_chat_generate_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: dict, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, ) -> LLMResult: choices = response.get("choices", []) - assistant_message = AssistantPromptMessage(content='', tool_calls=[]) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if choices and choices[0]["finish_reason"] == "tool_calls": for choice in choices: for tool_call in choice["message"]["tool_calls"]: @@ -194,7 +183,7 @@ def _handle_chat_generate_response( type=tool_call.get("type", ""), function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.get("function", {}).get("name", ""), - arguments=tool_call.get("function", {}).get("arguments", "") + arguments=tool_call.get("function", {}).get("arguments", ""), ), ) assistant_message.tool_calls.append(tool) @@ -228,11 +217,11 @@ def _handle_chat_generate_response( ) def _handle_chat_generate_stream_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Iterator, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, ) -> Generator: for line in response: if not line: @@ -260,9 +249,7 @@ def _handle_chat_generate_stream_response( prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=choice["delta"]["content"], tool_calls=[] - ), + message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]), finish_reason=stop_reason, ), ) diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 81bd58e3ce1a55..b7276fabb5dd67 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model. """ - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -45,28 +46,23 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: # embedding chunk - chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user - ) + chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user) embeddings.extend(chunk_embeddings) token_usage += chunk_usage @@ -74,17 +70,14 @@ def _invoke(self, model: str, credentials: dict, result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -95,56 +88,47 @@ def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = :return: embeddings result """ url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': + elif err == "insufficient_quota": raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err == "invalid_authentication": + raise InvalidAuthenticationError(msg) + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] - + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,32 +154,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -207,10 +183,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -221,7 +194,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index e99bc52ff8258b..1cfc1d199cbf8d 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,13 +20,10 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `amazon.titan-text-lite-v1` model by default for validating credentials - model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') - model_instance.validate_credentials( - model=model_for_validation, - credentials=credentials - ) + model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1") + model_instance.validate_credentials(model=model_for_validation, credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c325ac3cecb295..a2a69b86bbf963 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -45,36 +45,42 @@ logger = logging.getLogger(__name__) -class BedrockLargeLanguageModel(LargeLanguageModel): +class BedrockLargeLanguageModel(LargeLanguageModel): # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed. - CONVERSE_API_ENABLED_MODEL_INFO=[ - {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} + CONVERSE_API_ENABLED_MODEL_INFO = [ + {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, ] @staticmethod def _find_model_info(model_id): for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO: - if model_id.startswith(model['prefix']): + if model_id.startswith(model["prefix"]): return model logger.info(f"current model id: {model_id} did not support by Converse API") return None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -88,17 +94,28 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - - model_info= BedrockLargeLanguageModel._find_model_info(model) + + model_info = BedrockLargeLanguageModel._find_model_info(model) if model_info: - model_info['model'] = model + model_info["model"] = model # invoke models via boto3 converse API - return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) + return self._generate_with_converse( + model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools + ) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: + def _generate_with_converse( + self, + model_info: dict, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model with converse API @@ -110,35 +127,39 @@ def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_me :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client(service_name='bedrock-runtime', - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"]) + bedrock_client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=credentials.get("aws_access_key_id"), + aws_secret_access_key=credentials.get("aws_secret_access_key"), + region_name=credentials["aws_region"], + ) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) parameters = { - 'modelId': model_info['model'], - 'messages': prompt_message_dicts, - 'inferenceConfig': inference_config, - 'additionalModelRequestFields': additional_model_fields, + "modelId": model_info["model"], + "messages": prompt_message_dicts, + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_model_fields, } - if model_info['support_system_prompts'] and system and len(system) > 0: - parameters['system'] = system + if model_info["support_system_prompts"] and system and len(system) > 0: + parameters["system"] = system - if model_info['support_tool_use'] and tools: - parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + if model_info["support_tool_use"] and tools: + parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: if stream: response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_stream_response( + model_info["model"], credentials, response, prompt_messages + ) else: response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: @@ -149,8 +170,10 @@ def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_me except Exception as ex: raise InvokeError(str(ex)) - def _handle_converse_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_converse_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -160,36 +183,30 @@ def _handle_converse_response(self, model: str, credentials: dict, response: dic :param prompt_messages: prompt messages :return: full response chunk generator result """ - response_content = response['output']['message']['content'] + response_content = response["output"]["message"]["content"] # transform assistant message to prompt message - if response['stopReason'] == 'tool_use': + if response["stopReason"] == "tool_use": tool_calls = [] text, tool_use = self._extract_tool_use(response_content) tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=json.dumps(tool_use['input']) - ) + name=tool_use["name"], arguments=json.dumps(tool_use["input"]) + ), ) tool_calls.append(tool_call) - assistant_prompt_message = AssistantPromptMessage( - content=text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls) else: - assistant_prompt_message = AssistantPromptMessage( - content=response_content[0]['text'] - ) + assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"]) # calculate num tokens - if response['usage']: + if response["usage"]: # transform usage - prompt_tokens = response['usage']['inputTokens'] - completion_tokens = response['usage']['outputTokens'] + prompt_tokens = response["usage"]["inputTokens"] + completion_tokens = response["usage"]["outputTokens"] else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -206,20 +223,25 @@ def _handle_converse_response(self, model: str, credentials: dict, response: dic ) return result - def _extract_tool_use(self, content:dict)-> tuple[str, dict]: + def _extract_tool_use(self, content: dict) -> tuple[str, dict]: tool_use = {} - text = '' + text = "" for item in content: - if 'toolUse' in item: - tool_use = item['toolUse'] - elif 'text' in item: - text = item['text'] + if "toolUse" in item: + tool_use = item["toolUse"] + elif "text" in item: + text = item["text"] else: raise ValueError(f"Got unknown item: {item}") return text, tool_use - def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_converse_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -231,7 +253,7 @@ def _handle_converse_stream_response(self, model: str, credentials: dict, respon """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -240,87 +262,85 @@ def _handle_converse_stream_response(self, model: str, credentials: dict, respon tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_use = {} - for chunk in response['stream']: - if 'messageStart' in chunk: + for chunk in response["stream"]: + if "messageStart" in chunk: return_model = model - elif 'messageStop' in chunk: - finish_reason = chunk['messageStop']['stopReason'] - elif 'contentBlockStart' in chunk: - tool = chunk['contentBlockStart']['start']['toolUse'] - tool_use['toolUseId'] = tool['toolUseId'] - tool_use['name'] = tool['name'] - elif 'metadata' in chunk: - input_tokens = chunk['metadata']['usage']['inputTokens'] - output_tokens = chunk['metadata']['usage']['outputTokens'] + elif "messageStop" in chunk: + finish_reason = chunk["messageStop"]["stopReason"] + elif "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"]["toolUse"] + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + elif "metadata" in chunk: + input_tokens = chunk["metadata"]["usage"]["inputTokens"] + output_tokens = chunk["metadata"]["usage"]["outputTokens"] usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - elif 'contentBlockDelta' in chunk: - delta = chunk['contentBlockDelta']['delta'] - if 'text' in delta: - chunk_text = delta['text'] if delta['text'] else '' + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + chunk_text = delta["text"] if delta["text"] else "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text if chunk_text else "", ) - index = chunk['contentBlockDelta']['contentBlockIndex'] + index = chunk["contentBlockDelta"]["contentBlockIndex"] yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index+1, + index=index + 1, message=assistant_prompt_message, - ) + ), ) - elif 'toolUse' in delta: - if 'input' not in tool_use: - tool_use['input'] = '' - tool_use['input'] += delta['toolUse']['input'] - elif 'contentBlockStop' in chunk: - if 'input' in tool_use: + elif "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "contentBlockStop" in chunk: + if "input" in tool_use: tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=tool_use['input'] - ) + name=tool_use["name"], arguments=tool_use["input"] + ), ) tool_calls.append(tool_call) tool_use = {} except Exception as ex: raise InvokeError(str(ex)) - - def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]: + + def _convert_converse_api_model_parameters( + self, model_parameters: dict, stop: Optional[list[str]] = None + ) -> tuple[dict, dict]: inference_config = {} additional_model_fields = {} - if 'max_tokens' in model_parameters: - inference_config['maxTokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters: + inference_config["maxTokens"] = model_parameters["max_tokens"] + + if "temperature" in model_parameters: + inference_config["temperature"] = model_parameters["temperature"] - if 'temperature' in model_parameters: - inference_config['temperature'] = model_parameters['temperature'] - - if 'top_p' in model_parameters: - inference_config['topP'] = model_parameters['temperature'] + if "top_p" in model_parameters: + inference_config["topP"] = model_parameters["temperature"] if stop: - inference_config['stopSequences'] = stop - - if 'top_k' in model_parameters: - additional_model_fields['top_k'] = model_parameters['top_k'] - + inference_config["stopSequences"] = stop + + if "top_k" in model_parameters: + additional_model_fields["top_k"] = model_parameters["top_k"] + return inference_config, additional_model_fields def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: @@ -332,7 +352,7 @@ def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage] prompt_message_dicts = [] for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() system.append({"text": message.content}) else: prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) @@ -349,15 +369,13 @@ def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] "toolSpec": { "name": tool.name, "description": tool.description, - "inputSchema": { - "json": tool.parameters - } + "inputSchema": {"json": tool.parameters}, } } ) tool_config["tools"] = configs return tool_config - + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict @@ -365,15 +383,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": "user", "content": [{'text': message.content}]} + message_dict = {"role": "user", "content": [{"text": message.content}]} else: sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -384,7 +400,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: image_content = requests.get(url).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -394,16 +410,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: image_content = base64.b64decode(base64_data) if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { - "image": { - "format": mime_type.replace('image/', ''), - "source": { - "bytes": image_content - } - } + "image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}} } sub_messages.append(sub_message_dict) @@ -412,36 +425,46 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message = cast(AssistantPromptMessage, message) if message.tool_calls: message_dict = { - "role": "assistant", "content":[{ - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": json.loads(message.tool_calls[0].function.arguments) + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": json.loads(message.tool_calls[0].function.arguments), + } } - }] + ], } else: - message_dict = {"role": "assistant", "content": [{'text': message.content}]} + message_dict = {"role": "assistant", "content": [{"text": message.content}]} elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = [{'text': message.content}] + message_dict = [{"text": message.content}] elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"json": {"text": message.content}}] - } - }] + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": {"text": message.content}}], + } + } + ], } else: raise ValueError(f"Got unknown type {message}") return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -451,15 +474,14 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :param tools: tools for tool calling :return:md = genai.GenerativeModel(model) """ - prefix = model.split('.')[0] - model_name = model.split('.')[1] - + prefix = model.split(".")[0] + model_name = model.split(".")[1] + if isinstance(prompt_messages, str): prompt = prompt_messages else: prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name) - return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -482,24 +504,28 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "topP": 0.9, "maxTokens": 32, } - + try: ping_message = UserPromptMessage(content="ping") - self._invoke(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters=required_params, - stream=False) - + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters=required_params, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Convert a single message to a string. @@ -514,7 +540,7 @@ def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str if isinstance(message, UserPromptMessage): body = content - if (isinstance(content, list)): + if isinstance(content, list): body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT]) message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}" elif isinstance(message, AssistantPromptMessage): @@ -528,7 +554,9 @@ def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -537,27 +565,31 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefi :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message, model_prefix, model_name) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} - model_prefix = model.split('.')[0] - model_name = model.split('.')[1] + model_prefix = model.split(".")[0] + model_name = model.split(".")[1] if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") @@ -571,21 +603,27 @@ def _create_payload(self, model: str, prompt_messages: list[PromptMessage], mode payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -598,18 +636,16 @@ def _generate(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream) # need workaround for ai21 models which doesn't support streaming @@ -619,18 +655,13 @@ def _generate(self, model: str, credentials: dict, invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) - response = invoke( - modelId=model, - contentType="application/json", - accept= "*/*", - body=body_jsonstr - ) + body_jsonstr = json.dumps(payload) + response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -639,15 +670,15 @@ def _generate(self, model: str, credentials: dict, except Exception as ex: raise InvokeError(str(ex)) - if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -657,7 +688,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -665,25 +696,23 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens(model, credentials, output if output else "") + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") # construct assistant message from output - assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -698,8 +727,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -709,65 +739,59 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = response_body.get('completions')[0].get('finish_reason') + content = response_body.get("completions")[0].get("data").get("text") + finish_reason = response_body.get("completions")[0].get("finish_reason") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - finish_reason=finish_reason, - usage=usage - ) - ) + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage + ), + ) return - - stream = response.get('body') + + stream = response.get("body") if not stream: - raise InvokeError('No response body') - + raise InvokeError("No response body") + index = -1 for event in stream: - chunk = event.get('chunk') - + chunk = event.get("chunk") + if not chunk: exception_name = next(iter(event)) full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" raise self._map_client_to_invoke_error(exception_name, full_ex_msg) - payload = json.loads(chunk.get('bytes').decode()) + payload = json.loads(chunk.get("bytes").decode()) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content = content_delta if content_delta else '', + content=content_delta if content_delta else "", ) index += 1 - + if not finish_reason: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: @@ -777,18 +801,15 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage + ), ) - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -804,9 +825,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: """ Map client error to invoke error @@ -822,7 +843,12 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I return InvokeBadRequestError(error_msg) elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index ef22a9c86824e0..2d898e3aaacfec 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -27,12 +27,11 @@ logger = logging.getLogger(__name__) -class BedrockTextEmbeddingModel(TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: +class BedrockTextEmbeddingModel(TextEmbeddingModel): + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -42,67 +41,56 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) embeddings = [] token_usage = 0 - - model_prefix = model.split('.')[0] - - if model_prefix == "amazon" : + + model_prefix = model.split(".")[0] + + if model_prefix == "amazon": for text in texts: body = { - "inputText": text, + "inputText": text, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - logger.warning(f'Total Tokens: {token_usage}') + embeddings.extend([response_body.get("embedding")]) + token_usage += response_body.get("inputTextTokenCount") + logger.warning(f"Total Tokens: {token_usage}") result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - if model_prefix == "cohere" : - input_type = 'search_document' if len(texts) > 1 else 'search_query' + if model_prefix == "cohere": + input_type = "search_document" if len(texts) > 1 else "search_query" for text in texts: body = { - "texts": [text], - "input_type": input_type, + "texts": [text], + "input_type": input_type, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get('embeddings')) + embeddings.extend(response_body.get("embeddings")) token_usage += len(text) result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - #others + # others raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -125,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -141,19 +129,25 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - - def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + + def _create_payload( + self, + model_prefix: str, + texts: list[str], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} if model_prefix == "amazon": - payload['inputText'] = texts + payload["inputText"] = texts - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -165,10 +159,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +170,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -199,31 +190,37 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I return InvokeBadRequestError(error_msg) elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) return InvokeError(error_msg) - - def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): - accept = 'application/json' - content_type = 'application/json' + def _invoke_bedrock_embedding( + self, + model: str, + bedrock_runtime, + body: dict, + ): + accept = "application/json" + content_type = "application/json" try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, - contentType=content_type + body=json.dumps(body), modelId=model, accept=accept, contentType=content_type ) - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) return response_body except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) diff --git a/api/core/model_runtime/model_providers/chatglm/chatglm.py b/api/core/model_runtime/model_providers/chatglm/chatglm.py index e9dd5794f31ce2..71d9a1532281bd 100644 --- a/api/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/api/core/model_runtime/model_providers/chatglm/chatglm.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `chatglm3-6b` model for validate, - model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials - ) + model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index e83d08af714469..114acc1ec3a942 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -43,12 +43,19 @@ logger = logging.getLogger(__name__) + class ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +78,16 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +108,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +141,24 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError + PermissionDeniedError, ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -163,35 +180,31 @@ def _generate(self, model: str, credentials: dict, extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +225,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +236,12 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +252,14 @@ def _extract_response_tool_calls(self, if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +273,20 @@ def _to_client_kwargs(self, credentials: dict) -> dict: client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['api_base']) / 'v1') + "base_url": str(URL(credentials["api_base"]) / "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,9 +294,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: @@ -295,23 +306,25 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +333,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +348,15 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -359,15 +376,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -378,7 +394,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response ) return response - + def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +411,19 @@ def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageT return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +432,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +470,37 @@ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/api/core/model_runtime/model_providers/cohere/cohere.py b/api/core/model_runtime/model_providers/cohere/cohere.py index cfbcb94d2624f1..8394a45fcf9ca1 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.py +++ b/api/core/model_runtime/model_providers/cohere/cohere.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.RERANK) # Use `rerank-english-v2.0` model for validate, - model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials - ) + model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 89b04c0279f760..203ca9c4a00f66 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -85,7 +91,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -95,11 +101,16 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -136,30 +147,37 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -173,17 +191,17 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop if stream: response = client.generate_stream( prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_stream_response(model, credentials, response, prompt_messages) @@ -192,14 +210,14 @@ def _generate(self, model: str, credentials: dict, prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Generation, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -212,9 +230,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens prompt_tokens = int(response.meta.billed_units.input_tokens) @@ -225,17 +241,18 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[GenerateStreamedResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[GenerateStreamedResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -245,7 +262,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, GenerateStreamedResponse_TextGeneration): chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) @@ -255,9 +272,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -267,7 +282,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -277,9 +292,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) completion_tokens = self._num_tokens_from_messages( - model, - credentials, - [AssistantPromptMessage(content=full_assistant_content)] + model, credentials, [AssistantPromptMessage(content=full_assistant_content)] ) # transform usage @@ -290,20 +303,27 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=chunk.finish_reason, - usage=usage - ) + usage=usage, + ), ) break elif isinstance(chunk, GenerateStreamedResponse_StreamError): chunk = cast(GenerateStreamedResponse_StreamError, chunk) raise InvokeBadRequestError(chunk.err) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -318,27 +338,28 @@ def _chat_generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['stop_sequences'] = stop + model_parameters["stop_sequences"] = stop if tools: if len(tools) == 1: raise ValueError("Cohere tool call requires at least two tools to be specified.") - model_parameters['tools'] = self._convert_tools(tools) + model_parameters["tools"] = self._convert_tools(tools) - message, chat_histories, tool_results \ - = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories( + prompt_messages + ) if tool_results: - model_parameters['tool_results'] = tool_results + model_parameters["tool_results"] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") if stream: response = client.chat_stream( @@ -346,7 +367,7 @@ def _chat_generate(self, model: str, credentials: dict, chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) @@ -356,14 +377,14 @@ def _chat_generate(self, model: str, credentials: dict, chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_chat_generate_response( + self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -380,19 +401,15 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response for cohere_tool_call in response.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls) # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) @@ -403,17 +420,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[StreamedChatResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -423,17 +441,16 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, :return: llm response chunk generator """ - def final_response(full_text: str, - tool_calls: list[AssistantPromptMessage.ToolCall], - index: int, - finish_reason: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text, - tool_calls=tool_calls - ) + full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage @@ -444,14 +461,14 @@ def final_response(full_text: str, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content='', tool_calls=tool_calls), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" tool_calls = [] for chunk in response: if isinstance(chunk, StreamedChatResponse_TextGeneration): @@ -462,9 +479,7 @@ def final_response(full_text: str, continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -474,7 +489,7 @@ def final_response(full_text: str, delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -484,11 +499,10 @@ def final_response(full_text: str, for cohere_tool_call in chunk.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) elif isinstance(chunk, StreamedChatResponse_StreamEnd): @@ -496,8 +510,9 @@ def final_response(full_text: str, yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -510,13 +525,14 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages prompt_message = cast(AssistantPromptMessage, prompt_message) if prompt_message.tool_calls: for tool_call in prompt_message.tool_calls: - latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( - call=ToolCall( - name=tool_call.function.name, - parameters=json.loads(tool_call.function.arguments) - ), - outputs=[] - )) + latest_tool_call_n_outputs.append( + ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments) + ), + outputs=[], + ) + ) else: cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) if cohere_prompt_message: @@ -529,12 +545,9 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages if tool_call_n_outputs.call.name == prompt_message.tool_call_id: latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( call=ToolCall( - name=tool_call_n_outputs.call.name, - parameters=tool_call_n_outputs.call.parameters + name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters ), - outputs=[{ - "result": prompt_message.content - }] + outputs=[{"result": prompt_message.content}], ) break i += 1 @@ -556,7 +569,7 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages latest_message = chat_histories.pop() message = latest_message.message else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories, latest_tool_call_n_outputs @@ -569,7 +582,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[Ch if isinstance(message.content, str): chat_message = ChatMessage(role="USER", message=message.content) else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) @@ -597,8 +610,8 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]: """ cohere_tools = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] parameter_definitions = {} for p_key, p_val in properties.items(): @@ -606,21 +619,16 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]: if p_key in required_properties: required = True - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" parameter_definitions[p_key] = ToolParameterDefinitionsValue( - description=desc, - type=p_val['type'], - required=required + description=desc, type=p_val["type"], required=required ) cohere_tool = Tool( - name=tool.name, - description=tool.description, - parameter_definitions=parameter_definitions + name=tool.name, description=tool.description, parameter_definitions=parameter_definitions ) cohere_tools.append(cohere_tool) @@ -637,12 +645,9 @@ def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> i :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return len(response.tokens) @@ -658,30 +663,30 @@ def _num_tokens_from_messages(self, model: str, credentials: dict, messages: lis real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -691,16 +696,13 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -716,22 +718,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index d2fdb30c6feec9..aba8fedbc097e5 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel): Model class for Cohere rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -38,20 +44,17 @@ def _invoke(self, model: str, credentials: dict, :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) response = client.rerank( query=query, documents=docs, model=model, top_n=top_n, return_documents=True, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) rerank_documents = [] @@ -70,10 +73,7 @@ def _invoke(self, model: str, credentials: dict, else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -94,7 +94,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -110,22 +110,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 0540fb740f7e90..a1c5e98118e4c6 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -46,14 +46,10 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text - ) + tokenize_response = self._tokenize(model=model, credentials=credentials, text=text) for j in range(0, len(tokenize_response), context_size): - tokens += [tokenize_response[j: j + context_size]] + tokens += [tokenize_response[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -62,9 +58,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]] ) used_tokens += embedding_used_tokens @@ -80,9 +74,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -92,17 +84,9 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,14 +100,10 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: - response = self._tokenize( - model=model, - credentials=credentials, - text=full_text - ) + response = self._tokenize(model=model, credentials=credentials, text=full_text) except Exception as e: raise self._transform_invoke_error(e) @@ -141,14 +121,9 @@ def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]: return [] # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model, - offline=False, - request_options=RequestOptions(max_retries=0) - ) + response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0)) return response.token_strings @@ -162,11 +137,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -180,14 +151,14 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query', - request_options=RequestOptions(max_retries=1) + input_type="search_document" if len(texts) > 1 else "search_query", + request_options=RequestOptions(max_retries=1), ) return response.embeddings, int(response.meta.billed_units.input_tokens) @@ -203,10 +174,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -217,7 +185,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -233,22 +201,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py index d61fd4ddc80457..10feef897272db 100644 --- a/api/core/model_runtime/model_providers/deepseek/deepseek.py +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -7,9 +7,7 @@ logger = logging.getLogger(__name__) - class DeepSeekProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `deepseek-chat` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='deepseek-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index bdb3823b60e739..6d0a3ee2628ea2 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -13,12 +13,17 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -27,10 +32,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -48,8 +51,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -69,10 +73,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -103,11 +107,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.deepseek.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.deepseek.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" - + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py index 5f282702bb03ef..e69de29bb2d1d6 100644 --- a/api/core/model_runtime/model_providers/fishaudio/__init__.py +++ b/api/core/model_runtime/model_providers/fishaudio/__init__.py @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py index 9f80996d9dcdef..3bc4b533e0815a 100644 --- a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py @@ -1,4 +1,4 @@ -import logging +import logging from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -18,11 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: """ try: model_instance = self.get_model_instance(ModelType.TTS) - model_instance.validate_credentials( - credentials=credentials - ) + model_instance.validate_credentials(credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py index 5b673ce18662d8..895a7a914c97de 100644 --- a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional import httpx @@ -12,9 +12,7 @@ class FishAudioText2SpeechModel(TTSModel): Model class for Fish.audio Text to Speech model. """ - def get_tts_model_voices( - self, model: str, credentials: dict, language: Optional[str] = None - ) -> list: + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: api_base = credentials.get("api_base", "https://api.fish.audio") api_key = credentials.get("api_key") use_public_models = credentials.get("use_public_models", "false") == "true" @@ -68,9 +66,7 @@ def _invoke( voice=voice, ) - def validate_credentials( - self, credentials: dict, user: Optional[str] = None - ) -> None: + def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: """ Validate credentials for text2speech model @@ -91,9 +87,7 @@ def validate_credentials( except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming( - self, model: str, credentials: dict, content_text: str, voice: str - ) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ Invoke streaming text2speech model :param model: model name @@ -106,12 +100,10 @@ def _tts_invoke_streaming( try: word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: - sentences = self._split_text_into_sentences( - content_text, max_length=word_limit - ) + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) else: sentences = [content_text.strip()] - + for i in range(len(sentences)): yield from self._tts_invoke_streaming_sentence( credentials=credentials, content_text=sentences[i], voice=voice @@ -120,9 +112,7 @@ def _tts_invoke_streaming( except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _tts_invoke_streaming_sentence( - self, credentials: dict, content_text: str, voice: Optional[str] = None - ) -> any: + def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any: """ Invoke streaming text2speech model @@ -141,20 +131,14 @@ def _tts_invoke_streaming_sentence( with httpx.stream( "POST", api_url + "/v1/tts", - json={ - "text": content_text, - "reference_id": voice, - "latency": latency - }, + json={"text": content_text, "reference_id": voice, "latency": latency}, headers={ "Authorization": f"Bearer {api_key}", }, timeout=None, ) as response: if response.status_code != 200: - raise InvokeBadRequestError( - f"Error: {response.status_code} - {response.text}" - ) + raise InvokeBadRequestError(f"Error: {response.status_code} - {response.text}") yield from response.iter_bytes() @property diff --git a/api/core/model_runtime/model_providers/google/google.py b/api/core/model_runtime/model_providers/google/google.py index ba25c74e71e857..70f56a8337b2e6 100644 --- a/api/core/model_runtime/model_providers/google/google.py +++ b/api/core/model_runtime/model_providers/google/google.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-pro` model for validate, - model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-pro", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 11f9f32f96dd14..274ff02095b105 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -49,12 +49,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -70,9 +75,14 @@ def _invoke(self, model: str, credentials: dict, """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -85,7 +95,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -95,13 +105,10 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -117,14 +124,16 @@ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -136,20 +145,25 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -163,14 +177,12 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -180,7 +192,7 @@ def _generate(self, model: str, credentials: dict, content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) @@ -194,7 +206,7 @@ def _generate(self, model: str, credentials: dict, google_model._client = new_custom_client - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, @@ -203,13 +215,11 @@ def _generate(self, model: str, credentials: dict, response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600} + request_options={"timeout": 600}, ) if stream: @@ -217,8 +227,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -229,9 +240,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -250,8 +259,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -264,9 +274,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index = -1 for chunk in response: for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -275,36 +283,31 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - + if not response._done: - # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -312,8 +315,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=index, message=assistant_prompt_message, finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -328,9 +331,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -353,65 +354,61 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: :return: glm Content representation of message """ if isinstance(message, UserPromptMessage): - glm_content = { - "role": "user", - "parts": [] - } - if (isinstance(message.content, str)): - glm_content['parts'].append(to_part(message.content)) + glm_content = {"role": "user", "parts": []} + if isinstance(message.content, str): + glm_content["parts"].append(to_part(message.content)) else: for c in message.content: if c.type == PromptMessageContentType.TEXT: - glm_content['parts'].append(to_part(c.data)) + glm_content["parts"].append(to_part(c.data)) elif c.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, c) if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, base64_data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] else: # fetch image data from url try: image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} - glm_content['parts'].append(blob) + blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): - glm_content = { - "role": "model", - "parts": [] - } + glm_content = {"role": "model", "parts": []} if message.content: - glm_content['parts'].append(to_part(message.content)) + glm_content["parts"].append(to_part(message.content)) if message.tool_calls: - glm_content["parts"].append(to_part(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))) + glm_content["parts"].append( + to_part( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ) return glm_content elif isinstance(message, SystemPromptMessage): - return { - "role": "user", - "parts": [to_part(message.content)] - } + return {"role": "user", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", - "parts": [glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))] + "parts": [ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], } else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -423,25 +420,20 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke emd = genai.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -457,5 +449,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py index b3f37b39678388..d0d5ff68f8090e 100644 --- a/api/core/model_runtime/model_providers/groq/groq.py +++ b/api/core/model_runtime/model_providers/groq/groq.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class GroqProvider(ModelProvider): +class GroqProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama3-8b-8192', - credentials=credentials - ) + model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py index 915f7a4e1a7e0d..352a7b519ee168 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llm.py +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -7,11 +7,17 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,6 +27,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' - + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index dd8ae526e6a759..3c4020b6eedf24 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -4,12 +4,6 @@ class _CommonHuggingfaceHub: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - HfHubHTTPError, - BadRequestError - ] - } + return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]} diff --git a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py index 15e2a4fed41be7..54d2a2bf399623 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py @@ -6,6 +6,5 @@ class HuggingfaceHubProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f43a8aedaf2c69..10c6d553f3d633 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -29,16 +29,23 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] - - if 'baichuan' in model.lower(): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] + + if "baichuan" in model.lower(): stream = False response = client.text_generation( @@ -47,98 +54,100 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes stream=stream, model=model, stop_sequences=stop, - **model_parameters) + **model_parameters, + ) if stream: return self._handle_generate_stream_response(model, credentials, prompt_messages, response) return self._handle_generate_response(model, credentials, prompt_messages, response) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'): - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + credentials["task_type"] = self._get_hosted_model_task_type( + credentials["huggingfacehub_api_token"], model + ) - if credentials['task_type'] not in ("text2text-generation", "text-generation"): - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, ' - 'text-generation.') + if credentials["task_type"] not in ("text2text-generation", "text-generation"): + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type must be one of text2text-generation, " "text-generation." + ) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] try: - client.text_generation( - prompt='Who are you?', - stream=True, - model=model) + client.text_generation(prompt="Who are you?", stream=True, model=model) except BadRequestError as e: - raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. ' - 'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.') + raise CredentialsValidateFailedError( + "Only available for models running on with the `text-generation-inference`. " + "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference." + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: - temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +157,15 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,30 +175,30 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, @@ -197,11 +206,9 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,9 +217,7 @@ def _handle_generate_stream_response(self, index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -240,15 +245,15 @@ def _handle_generate_stream_response(self, ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -270,15 +275,14 @@ def _get_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str): try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -287,10 +291,7 @@ def _get_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str): def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 0f0c166f3ec179..cb7a30bbe51343 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -13,40 +13,30 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( - json={ - "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } - }, - model=execute_model) + json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model + ) embeddings = json.loads(output.decode()) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -56,52 +46,48 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." + ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.") self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -128,24 +114,20 @@ def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -156,7 +138,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +148,26 @@ def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]: model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." + ) except Exception as e: raise ValueError(str(e)) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py index 94544662503974..97d7e28dc646f8 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -6,6 +6,5 @@ class HuggingfaceTeiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 34013426de5b77..c128c35f6db4c4 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -47,29 +47,29 @@ def _invoke( """ if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] try: results = TeiHelper.invoke_rerank(server_url, query, docs) rerank_documents = [] - for result in results: + for result in results: rerank_document = RerankDocument( - index=result['index'], - text=result['text'], - score=result['score'], + index=result["index"], + text=result["text"], + score=result["score"], ) - if score_threshold is None or result['score'] >= score_threshold: + if score_threshold is None or result["score"] >= score_threshold: rerank_documents.append(rerank_document) if top_n is not None and len(rerank_documents) >= top_n: break return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -80,21 +80,21 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) - if extra_args.model_type != 'reranker': - raise CredentialsValidateFailedError('Current model is not a rerank model') + if extra_args.model_type != "reranker": + raise CredentialsValidateFailedError("Current model is not a rerank model") - credentials['context_size'] = extra_args.max_input_length + credentials["context_size"] = extra_args.max_input_length self.invoke( model=model, credentials=credentials, - query='Whose kasumi', + query="Whose kasumi", docs=[ 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', - 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', - 'and she leads a team named PopiParty.', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", ], score_threshold=0.8, ) @@ -129,7 +129,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 2aa785c89d27e6..56c51e8888636c 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -31,16 +31,16 @@ def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraPa with cache_lock: if model_name not in cache: cache[model_name] = { - 'expires': time() + 300, - 'value': TeiHelper._get_tei_extra_parameter(server_url), + "expires": time() + 300, + "value": TeiHelper._get_tei_extra_parameter(server_url), } - return cache[model_name]['value'] + return cache[model_name]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -52,40 +52,38 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: get tei model extra parameter like model_type, max_input_length, max_batch_requests """ - url = str(URL(server_url) / 'info') + url = str(URL(server_url) / "info") # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: raise RuntimeError( - f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}" ) response_json = response.json() - model_type = response_json.get('model_type', {}) + model_type = response_json.get("model_type", {}) if len(model_type.keys()) < 1: - raise RuntimeError('model_type is empty') + raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ['embedding', 'reranker']: - raise RuntimeError(f'invalid model_type: {model_type}') - - max_input_length = response_json.get('max_input_length', 512) - max_client_batch_size = response_json.get('max_client_batch_size', 1) + if model_type not in ["embedding", "reranker"]: + raise RuntimeError(f"invalid model_type: {model_type}") + + max_input_length = response_json.get("max_input_length", 512) + max_client_batch_size = response_json.get("max_client_batch_size", 1) return TeiModelExtraParameter( - model_type=model_type, - max_input_length=max_input_length, - max_client_batch_size=max_client_batch_size + model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size ) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: """ @@ -116,12 +114,12 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: :param texts: texts to tokenize """ resp = httpx.post( - f'{server_url}/tokenize', - json={'inputs': texts}, + f"{server_url}/tokenize", + json={"inputs": texts}, ) resp.raise_for_status() return resp.json() - + @staticmethod def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ @@ -149,8 +147,8 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ # Use OpenAI compatible API here, which has usage tracking resp = httpx.post( - f'{server_url}/v1/embeddings', - json={'input': texts}, + f"{server_url}/v1/embeddings", + json={"input": texts}, ) resp.raise_for_status() return resp.json() @@ -173,11 +171,11 @@ def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: :param texts: texts to rerank :param candidates: candidates to rerank """ - params = {'query': query, 'texts': docs, 'return_text': True} + params = {"query": query, "texts": docs, "return_text": True} response = httpx.post( - server_url + '/rerank', + server_url + "/rerank", json=params, ) - response.raise_for_status() + response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 6897b87f6d7525..2d04abb277fcee 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -40,12 +40,11 @@ def _invoke( :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] - # get model properties context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -58,7 +57,6 @@ def _invoke( batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): - # Check if the number of tokens is larger than the context size num_tokens = len(tokenize_result) @@ -66,20 +64,22 @@ def _invoke( # Find the best cutoff point pre_special_token_count = 0 for token in tokenize_result: - if token['special']: + if token["special"]: pre_special_token_count += 1 else: break - rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count + rest_special_token_count = ( + len([token for token in tokenize_result if token["special"]]) - pre_special_token_count + ) # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit token_cutoff = context_size - rest_special_token_count - 20 # Find the cutoff index cutpoint_token = tokenize_result[token_cutoff] - cutoff = cutpoint_token['start'] + cutoff = cutpoint_token["start"] - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -92,12 +92,12 @@ def _invoke( for i in _iter: iter_texts = inputs[i : i + max_chunks] results = TeiHelper.invoke_embeddings(server_url, iter_texts) - embeddings = results['data'] - embeddings = [embedding['embedding'] for embedding in embeddings] + embeddings = results["data"] + embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) - usage = results['usage'] - used_tokens += usage['total_tokens'] + usage = results["usage"] + used_tokens += usage["total_tokens"] except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -117,9 +117,9 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int :return: """ num_tokens = 0 - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) @@ -135,15 +135,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) print(extra_args) - if extra_args.model_type != 'embedding': - raise CredentialsValidateFailedError('Current model is not a embedding model') + if extra_args.model_type != "embedding": + raise CredentialsValidateFailedError("Current model is not a embedding model") - credentials['context_size'] = extra_args.max_input_length - credentials['max_chunks'] = extra_args.max_client_batch_size - self._invoke(model=model, credentials=credentials, texts=['ping']) + credentials["context_size"] = extra_args.max_input_length + credentials["max_chunks"] = extra_args.max_client_batch_size + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -195,8 +195,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ - ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py index 5a298d33acac5c..e65772e7dda3a1 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class HunyuanProvider(ModelProvider): +class HunyuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +19,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `hunyuan-standard` model for validate, - model_instance.validate_credentials( - model='hunyuan-standard', - credentials=credentials - ) + model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 0bdf6ec005056b..c056ab7a0855c7 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -23,21 +23,27 @@ logger = logging.getLogger(__name__) -class HunyuanLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: +class HunyuanLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = self._setup_hunyuan_client(credentials) request = models.ChatCompletionsRequest() messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages) custom_parameters = { - 'Temperature': model_parameters.get('temperature', 0.0), - 'TopP': model_parameters.get('top_p', 1.0), - 'EnableEnhancement': model_parameters.get('enable_enhance', True) + "Temperature": model_parameters.get("temperature", 0.0), + "TopP": model_parameters.get("top_p", 1.0), + "EnableEnhancement": model_parameters.get("enable_enhance", True), } params = { @@ -47,16 +53,19 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes **custom_parameters, } # add Tools and ToolChoice - if (tools and len(tools) > 0): - params['ToolChoice'] = "auto" - params['Tools'] = [{ - "Type": "function", - "Function": { - "Name": tool.name, - "Description": tool.description, - "Parameters": json.dumps(tool.parameters) + if tools and len(tools) > 0: + params["ToolChoice"] = "auto" + params["Tools"] = [ + { + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters), + }, } - } for tool in tools] + for tool in tools + ] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -76,22 +85,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None: req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -106,92 +112,96 @@ def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage] for message in prompt_messages: if isinstance(message, AssistantPromptMessage): tool_calls = message.tool_calls - if (tool_calls and len(tool_calls) > 0): + if tool_calls and len(tool_calls) > 0: dict_tool_calls = [ { "Id": tool_call.id, "Type": tool_call.type, "Function": { "Name": tool_call.function.name, - "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" - } - } for tool_call in tool_calls] - - dict_list.append({ - "Role": message.role.value, - # fix set content = "" while tool_call request - # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. - "Content": " ", # message.content if (message.content is not None) else "", - "ToolCalls": dict_tool_calls - }) + "Arguments": tool_call.function.arguments + if (tool_call.function.arguments == "") + else "{}", + }, + } + for tool_call in tool_calls + ] + + dict_list.append( + { + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls, + } + ) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) elif isinstance(message, ToolPromptMessage): - tool_execute_result = { "result": message.content } - content =json.dumps(tool_execute_result, ensure_ascii=False) - dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + tool_execute_result = {"result": message.content} + content = json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id}) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) return dict_list def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp): - tool_call = None tool_calls = [] for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) - data_str = event['data'] + data_str = event["data"] data = json.loads(data_str) - choices = data.get('Choices', []) + choices = data.get("Choices", []) if not choices: continue choice = choices[0] - delta = choice.get('Delta', {}) - message_content = delta.get('Content', '') - finish_reason = choice.get('FinishReason', '') + delta = choice.get("Delta", {}) + message_content = delta.get("Content", "") + finish_reason = choice.get("FinishReason", "") - usage = data.get('Usage', {}) - prompt_tokens = usage.get('PromptTokens', 0) - completion_tokens = usage.get('CompletionTokens', 0) + usage = data.get("Usage", {}) + prompt_tokens = usage.get("PromptTokens", 0) + completion_tokens = usage.get("CompletionTokens", 0) - response_tool_calls = delta.get('ToolCalls') - if (response_tool_calls is not None): + response_tool_calls = delta.get("ToolCalls") + if response_tool_calls is not None: new_tool_calls = self._extract_response_tool_calls(response_tool_calls) - if (len(new_tool_calls) > 0): + if len(new_tool_calls) > 0: new_tool_call = new_tool_calls[0] - if (tool_call is None): tool_call = new_tool_call - elif (tool_call.id != new_tool_call.id): + if tool_call is None: + tool_call = new_tool_call + elif tool_call.id != new_tool_call.id: tool_calls.append(tool_call) tool_call = new_tool_call else: tool_call.function.name += new_tool_call.function.name tool_call.function.arguments += new_tool_call.function.arguments - if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0: tool_calls.append(tool_call) tool_call = None - assistant_prompt_message = AssistantPromptMessage( - content=message_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[]) # rewrite content = "" while tool_call to avoid show content on web page - if (len(tool_calls) > 0): assistant_prompt_message.content = "" - + if len(tool_calls) > 0: + assistant_prompt_message.content = "" + # add tool_calls to assistant_prompt_message - if (finish_reason == 'tool_calls'): + if finish_reason == "tool_calls": assistant_prompt_message.tool_calls = tool_calls tool_call = None tool_calls = [] - if (len(finish_reason) > 0): + if len(finish_reason) > 0: usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) delta_chunk = LLMResultChunkDelta( index=index, - role=delta.get('Role', 'assistant'), + role=delta.get("Role", "assistant"), message=assistant_prompt_message, usage=usage, finish_reason=finish_reason, @@ -212,8 +222,9 @@ def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp ) def _handle_chat_response(self, credentials, model, prompt_messages, response): - usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, - response.Usage.CompletionTokens) + usage = self._calc_response_usage( + model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens + ) assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( @@ -225,8 +236,13 @@ def _handle_chat_response(self, credentials, model, prompt_messages, response): return result - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if len(prompt_messages) == 0: return 0 prompt = self._convert_messages_to_prompt(prompt_messages) @@ -241,10 +257,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -287,10 +300,8 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeError: [TencentCloudSDKException], } - - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -300,17 +311,14 @@ def _extract_response_tool_calls(self, tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: - response_function = response_tool_call.get('Function', {}) + response_function = response_tool_call.get("Function", {}) function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function.get('Name', ''), - arguments=response_function.get('Arguments', '') + name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get('Id', 0), - type='function', - function=function + id=response_tool_call.get("Id", 0), type="function", function=function ) tool_calls.append(tool_call) - return tool_calls \ No newline at end of file + return tool_calls diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 64d8dcf795f1c8..1396e59e188bca 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -19,14 +19,15 @@ logger = logging.getLogger(__name__) + class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for Hunyuan text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +38,9 @@ def _invoke(self, model: str, credentials: dict, :return: embeddings result """ - if model != 'hunyuan-embedding': - raise ValueError('Invalid model name') - + if model != "hunyuan-embedding": + raise ValueError("Invalid model name") + client = self._setup_hunyuan_client(credentials) embeddings = [] @@ -47,9 +48,7 @@ def _invoke(self, model: str, credentials: dict, for input in texts: request = models.GetEmbeddingRequest() - params = { - "Input": input - } + params = {"Input": input} request.from_json_string(json.dumps(params)) response = client.GetEmbedding(request) usage = response.Usage.TotalTokens @@ -60,11 +59,7 @@ def _invoke(self, model: str, credentials: dict, result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result @@ -79,22 +74,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None: req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -102,7 +94,7 @@ def _setup_hunyuan_client(self, credentials): clientProfile.httpProfile = httpProfile client = hunyuan_client.HunyuanClient(cred, "", clientProfile) return client - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -114,10 +106,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -128,11 +117,11 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -146,7 +135,7 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeError: [TencentCloudSDKException], } - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -170,4 +159,4 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int # response = client.GetTokenCount(request) # num_tokens += response.TokenCount - return num_tokens \ No newline at end of file + return num_tokens diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index cde4313495b4a8..33977b6a33fa49 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -8,7 +8,6 @@ class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `jina-embeddings-v2-base-en` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials - ) + model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index de7e038b9f31a6..d8394f7a4c1359 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -40,37 +47,32 @@ def _invoke(self, model: str, credentials: dict, if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.jina.ai/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.jina.ai/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -81,7 +83,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -92,7 +93,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -105,23 +106,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73ed9e929..d80cbfa83d6425 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ def _get_tokenizer(cls): with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 23203491e656fe..7ed3e4d384f27e 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - api_base: str = 'https://api.jina.ai/v1' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,29 +39,23 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") - base_url = credentials.get('base_url', self.api_base) - if base_url.endswith('/'): + base_url = credentials.get("base_url", self.api_base) + if base_url.endswith("/"): base_url = base_url[:-1] - url = base_url + '/embeddings' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} def transform_jina_input_text(model, text): - if model == 'jina-clip-v1': + if model == "jina-clip-v1": return {"text": text} return text - data = { - 'model': model, - 'input': [transform_jina_input_text(model, text) for text in texts] - } + data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} try: response = post(url, headers=headers, data=dumps(data)) @@ -70,7 +65,7 @@ def transform_jina_input_text(model, text): if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -81,25 +76,20 @@ def transform_jina_input_text(model, text): raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -128,30 +118,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: - raise CredentialsValidateFailedError( - f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError, - InvokeBadRequestError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -165,10 +143,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,24 +154,21 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) return entity diff --git a/api/core/model_runtime/model_providers/leptonai/leptonai.py b/api/core/model_runtime/model_providers/leptonai/leptonai.py index b035c31ac51453..34a55ff1924cf8 100644 --- a/api/core/model_runtime/model_providers/leptonai/leptonai.py +++ b/api/core/model_runtime/model_providers/leptonai/leptonai.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class LeptonAIProvider(ModelProvider): +class LeptonAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama2-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="llama2-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/leptonai/llm/llm.py b/api/core/model_runtime/model_providers/leptonai/llm/llm.py index 523309bac579a3..3d69417e45da72 100644 --- a/api/core/model_runtime/model_providers/leptonai/llm/llm.py +++ b/api/core/model_runtime/model_providers/leptonai/llm/llm.py @@ -8,18 +8,25 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_PREFIX_MAP = { - 'llama2-7b': 'llama2-7b', - 'gemma-7b': 'gemma-7b', - 'mistral-7b': 'mistral-7b', - 'mixtral-8x7b': 'mixtral-8x7b', - 'llama3-70b': 'llama3-70b', - 'llama2-13b': 'llama2-13b', - } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + "llama2-7b": "llama2-7b", + "gemma-7b": "gemma-7b", + "mistral-7b": "mistral-7b", + "mixtral-8x7b": "mixtral-8x7b", + "llama3-70b": "llama3-70b", + "llama2-13b": "llama2-13b", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -29,6 +36,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1' - \ No newline at end of file + credentials["mode"] = "chat" + credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1" diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 1009995c5868a1..94c03efe7b5b31 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -52,29 +52,48 @@ class LocalAILanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not supports """ def tokens(text: str): """ - We could not determine which tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We could not determine which tokenizer to use, cause the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -87,10 +106,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -142,30 +161,30 @@ def tokens(text: str): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -180,102 +199,104 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[], stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048")) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] - if completion_type == 'chat_completion': + if completion_type == "chat_completion": result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model_name, @@ -283,36 +304,32 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, prompt_messages=prompt_messages ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) def _to_client_kwargs(self, credentials: dict) -> dict: @@ -322,13 +339,13 @@ def _to_client_kwargs(self, credentials: dict) -> dict: :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -349,7 +366,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -359,11 +376,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}], } else: raise ValueError(f"Unknown message type {type(message)}") @@ -374,27 +387,29 @@ def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMes """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -411,18 +426,16 @@ def _handle_completion_generate_response(self, model: str, assistant_message = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -434,11 +447,14 @@ def _handle_completion_generate_response(self, model: str, return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -459,16 +475,14 @@ def _handle_chat_generate_response(self, model: str, tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -480,12 +494,15 @@ def _handle_chat_generate_response(self, model: str, return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -494,17 +511,11 @@ def _handle_completion_generate_stream_response(self, model: str, delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) @@ -512,8 +523,12 @@ def _handle_completion_generate_stream_response(self, model: str, completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -523,7 +538,7 @@ def _handle_completion_generate_stream_response(self, model: str, index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -539,12 +554,15 @@ def _handle_completion_generate_stream_response(self, model: str, full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -552,7 +570,7 @@ def _handle_chat_generate_stream_response(self, model: str, delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -564,22 +582,24 @@ def _handle_chat_generate_stream_response(self, model: str, # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -589,7 +609,7 @@ def _handle_chat_generate_stream_response(self, model: str, index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -605,9 +625,9 @@ def _handle_chat_generate_stream_response(self, model: str, full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -618,15 +638,10 @@ def _extract_response_tool_calls(self, if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls @@ -651,15 +666,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError + PermissionDeniedError, ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 6d2278fd541b1f..4ff898052b380d 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -6,6 +6,5 @@ class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py index c8ba9a6c7c1c8d..2b0f53bc19e8ec 100644 --- a/api/core/model_runtime/model_providers/localai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel): LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -43,45 +50,37 @@ def _invoke(self, model: str, credentials: dict, if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model - + if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - + raise CredentialsValidateFailedError("model_name is required") + url = server_url - headers = { - 'Authorization': f"Bearer {credentials.get('api_key')}", - 'Content-Type': 'application/json' - } + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} - data = { - "model": model_name, - "query": query, - "documents": docs, - "top_n": top_n - } + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} try: - response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) - response.raise_for_status() + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -92,7 +91,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -103,7 +101,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,21 +114,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={} + model_properties={}, ) return entity diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py index d7403aff4ffa7f..4b9d0f5bfefd33 100644 --- a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py @@ -32,8 +32,8 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional :param user: unique user id :return: text for given audio file """ - - url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + + url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions") data = {"model": model} files = {"file": file} @@ -42,7 +42,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional prepared_request = session.prepare_request(request) response = session.send(prepared_request) - if 'error' in response.json(): + if "error" in response.json(): raise InvokeServerUnavailableError("Empty response") return response.json()["text"] @@ -58,7 +58,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -66,36 +66,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 954c9d10f2a67f..7d258be81e0580 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,39 +38,33 @@ def _invoke(self, model: str, credentials: dict, :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - + raise CredentialsValidateFailedError("model_name is required") + url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} - data = { - 'model': model_name, - 'input': texts[0] - } + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +74,21 @@ def _invoke(self, model: str, credentials: dict, else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -114,7 +107,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema @@ -130,10 +123,10 @@ def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIMod features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,32 +138,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -182,10 +165,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -196,7 +176,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 6c41e0d2a5ed6b..96f99c892978e2 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -17,42 +17,48 @@ class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,40 +66,39 @@ def generate(self, model: str, api_key: str, group_id: str, # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] - - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise BadRequestError("At least one user message is required") + + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] + + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise InternalServerError(msg) @@ -110,65 +115,52 @@ def _handle_error(self, code: int, msg: str): def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens - } - message.stop_reason = data['choices'][0]['finish_reason'] + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens} + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + message = choice["delta"] + yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 55747057c9ff3b..0a2a67a56d78cd 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,86 +17,83 @@ class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement it, but the parameters are reserved + Minimax Chat Completion Pro API, supports function calling + however, we do not have enough time and energy to implement it, but the parameters are reserved """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') + raise InvalidAPIKeyError("Invalid API key or group ID") - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: - extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - - if model_parameters.get('plugin_web_search'): - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool: + extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if model_parameters.get("plugin_web_search"): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') + raise BadRequestError("At least one message is required") if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages = prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') + raise BadRequestError("At least one user message is required") messages = [message.to_dict() for message in prompt_messages] - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = {'type': 'auto'} + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) @@ -123,78 +120,72 @@ def _handle_error(self, code: int, msg: str): def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) # final chunk - if data['reply'] or data.get('usage'): - total_tokens = data['usage']['total_tokens'] - minimax_message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) + if data["reply"] or data.get("usage"): + total_tokens = data["usage"]["total_tokens"] + minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") minimax_message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - minimax_message.stop_reason = data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data["choices"][0]["finish_reason"] - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) > 0: for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append function_call message - if 'function_call' in message: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = message['function_call'] + if "function_call" in message: + function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message["function_call"] yield function_call_message yield minimax_message return # partial chunk - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append text message - if 'text' in message: - minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + if "text" in message: + minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value) yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6ca0ed1..309b5cf413bd54 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index feeba75f49832f..76ed704a75540f 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,18 +34,25 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6.5s-chat': MinimaxChatCompletionPro, - 'abab6.5-chat': MinimaxChatCompletionPro, - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6.5s-chat": MinimaxChatCompletionPro, + "abab6.5-chat": MinimaxChatCompletionPro, + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -53,82 +60,97 @@ def validate_credentials(self, model: str, credentials: dict) -> None: Validate credentials for Baichuan model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to calculate the num tokens, so we use str() to convert the prompt to string + not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide their own tokenizer of adab5.5 and abab5 model + therefore, we use gpt2 tokenizer instead """ messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface """ client: MinimaxChatCompletionPro = self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) @@ -136,26 +158,27 @@ def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessa return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) elif isinstance(prompt_message, ToolPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -166,31 +189,33 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if "name" not in message.function_call or "arguments" not in message.function_call: continue yield LLMResultChunk( @@ -199,15 +224,16 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], arguments=message.function_call["arguments"] + ), ) - )] + ], ), ), ) @@ -217,10 +243,7 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) @@ -236,22 +259,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9ac20d0..88ebe5e2e00e7a 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 +4,27 @@ class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: - return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call - } - + return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call} + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 52f6c2f1d3a098..5a761903a1eb12 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `abab5.5-chat` model for validate, - model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise CredentialsValidateFailedError(f"{ex}") diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 85dc6ef51d1a1d..02a53708be9c85 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model. """ - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -44,54 +45,43 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + data = {"model": "embo-01", "texts": texts, "type": "db"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -119,9 +109,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001: @@ -148,25 +138,17 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -178,10 +160,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +171,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py index 01ed8010de873c..da60bd7661d597 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -7,14 +7,19 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral dose not support user/stop arguments stop = [] user = None @@ -27,5 +32,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py index f1d825f6c6f042..7f9db8da1c1ddf 100644 --- a/api/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -8,7 +8,6 @@ class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index c233596637fa21..3ea46c2967e19c 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -30,11 +30,17 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,50 +55,50 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 4096)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ @@ -107,19 +113,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -129,14 +129,16 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -162,21 +164,26 @@ def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[ if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -186,11 +193,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -201,12 +209,7 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -220,9 +223,9 @@ def get_tool_call(tool_name: str): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -244,9 +247,9 @@ def get_tool_call(tool_name: str): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -255,21 +258,21 @@ def get_tool_call(tool_name: str): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -277,19 +280,18 @@ def get_tool_call(tool_name: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -305,26 +307,21 @@ def get_tool_call(tool_name: str): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index 5654ae1459cc16..4995e235f54c69 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -8,7 +8,6 @@ class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/novita/llm/llm.py b/api/core/model_runtime/model_providers/novita/llm/llm.py index 7662bf914a1e42..23367ed1b4309e 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llm.py +++ b/api/core/model_runtime/model_providers/novita/llm/llm.py @@ -8,20 +8,25 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - - credentials['endpoint_url'] = "https://api.novita.ai/v3/openai" - credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' } + credentials["endpoint_url"] = "https://api.novita.ai/v3/openai" + credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"} return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) self._add_custom_parameters(credentials, model) @@ -29,21 +34,36 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' + credentials["mode"] = "chat" - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/novita/novita.py b/api/core/model_runtime/model_providers/novita/novita.py index f1b72246057c6d..76a75b01e27e01 100644 --- a/api/core/model_runtime/model_providers/novita/novita.py +++ b/api/core/model_runtime/model_providers/novita/novita.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `meta-llama/llama-3-8b-instruct` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials=credentials - ) + model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index bc42eaca658bac..4d3747dc842ea4 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -21,31 +21,36 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { - 'fuyu-8b': 'vlm/adept/fuyu-8b', - 'mistralai/mistral-large': '', - 'mistralai/mixtral-8x7b-instruct-v0.1': '', - 'mistralai/mixtral-8x22b-instruct-v0.1': '', - 'google/gemma-7b': '', - 'google/codegemma-7b': '', - 'snowflake/arctic':'', - 'meta/llama2-70b': '', - 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '', - 'meta/llama-3.1-8b-instruct': '', - 'meta/llama-3.1-70b-instruct': '', - 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '', - 'nvidia/nemotron-4-340b-instruct': '', - 'microsoft/phi-3-medium-128k-instruct':'', - 'microsoft/phi-3-mini-128k-instruct':'' + "fuyu-8b": "vlm/adept/fuyu-8b", + "mistralai/mistral-large": "", + "mistralai/mixtral-8x7b-instruct-v0.1": "", + "mistralai/mixtral-8x22b-instruct-v0.1": "", + "google/gemma-7b": "", + "google/codegemma-7b": "", + "snowflake/arctic": "", + "meta/llama2-70b": "", + "meta/llama3-8b-instruct": "", + "meta/llama3-70b-instruct": "", + "meta/llama-3.1-8b-instruct": "", + "meta/llama-3.1-70b-instruct": "", + "meta/llama-3.1-405b-instruct": "", + "google/recurrentgemma-2b": "", + "nvidia/nemotron-4-340b-instruct": "", + "microsoft/phi-3-medium-128k-instruct": "", + "microsoft/phi-3-mini-128k-instruct": "", } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) prompt_messages = self._transform_prompt_messages(prompt_messages) stop = [] @@ -60,16 +65,14 @@ def _transform_prompt_messages(self, prompt_messages: list[PromptMessage]) -> li for i, p in enumerate(prompt_messages): if isinstance(p, UserPromptMessage) and isinstance(p.content, list): content = p.content - content_text = '' + content_text = "" for prompt_content in content: if prompt_content.type == PromptMessageContentType.TEXT: content_text += prompt_content.data else: content_text += f' ' - prompt_message = UserPromptMessage( - content=content_text - ) + prompt_message = UserPromptMessage(content=content_text) prompt_messages[i] = prompt_message return prompt_messages @@ -78,15 +81,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._validate_credentials(model, credentials) def _add_custom_parameters(self, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - + credentials["mode"] = "chat" + if self.MODEL_SUFFIX_MAP[model]: - credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}' - credentials.pop('endpoint_url') + credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}" + credentials.pop("endpoint_url") else: - credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1' + credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1" - credentials['stream_mode_delimiter'] = '\n' + credentials["stream_mode_delimiter"] = "\n" def _validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -97,72 +100,67 @@ def _validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -176,57 +174,51 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" if stream: - headers['Accept'] = 'text/event-stream' + headers["Accept"] = "text/event-stream" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") - # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -240,16 +232,10 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if not response.ok: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.py b/api/core/model_runtime/model_providers/nvidia/nvidia.py index e83f8badb57242..058fa003462585 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.py +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.py @@ -8,7 +8,6 @@ class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='mistralai/mixtral-8x7b-instruct-v0.1', - credentials=credentials - ) + model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 80c24b0555cd41..fabebc67ab0eeb 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel): """ def _sigmoid(self, logit: float) -> float: - return 1/(1+exp(-logit)) + return 1 / (1 + exp(-logit)) - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -60,9 +67,9 @@ def _invoke(self, model: str, credentials: dict, results = response.json() rerank_documents = [] - for result in results['rankings']: - index = result['index'] - logit = result['logit'] + for result in results["rankings"]: + index = result["index"] + logit = result["logit"] rerank_document = RerankDocument( index=index, text=docs[index], @@ -110,5 +117,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [requests.HTTPError], InvokeRateLimitError: [], InvokeAuthorizationError: [requests.HTTPError], - InvokeBadRequestError: [requests.RequestException] + InvokeBadRequestError: [requests.RequestException], } diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index a2adef400d404c..00cec265d5ed98 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Nvidia text embedding model. """ - api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings' - models: list[str] = ['NV-Embed-QA'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" + models: list[str] = ["NV-Embed-QA"] + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,32 +38,25 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': model, - 'input': texts[0], - 'input_type': 'query' - } + data = {"model": model, "input": texts[0], "input_type": "query"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +66,21 @@ def _invoke(self, model: str, credentials: dict, else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -117,30 +109,20 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -152,10 +134,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -166,7 +145,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py index f7b849fbe23b7f..6ff380bdd99c8b 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): """ Model class for NVIDIA NIM large language model. """ + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py index 25ab3e8e20f021..ad890ada22abc8 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -6,6 +6,5 @@ class NVIDIANIMProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 37787c459d80e9..ad5197a154c282 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -33,31 +33,29 @@ request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.command-r-plus", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"}, "chatRequest": { "apiFormat": "COHERE", - #"preambleOverride": "You are a helpful assistant.", - #"message": "Hello!", - #"chatHistory": [], + # "preambleOverride": "You are a helpful assistant.", + # "message": "Hello!", + # "chatHistory": [], "maxTokens": 600, "isStream": False, "frequencyPenalty": 0, "presencePenalty": 0, "temperature": 1, - "topP": 0.75 - } + "topP": 0.75, + }, } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + class OCILargeLanguageModel(LargeLanguageModel): # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm @@ -100,11 +98,17 @@ def _is_system_prompt_supported(self, model_id: str) -> bool: return False return feature["system"] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -118,22 +122,27 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - #print("model"+"*"*20) - #print(model) - #print("credentials"+"*"*20) - #print(credentials) - #print("model_parameters"+"*"*20) - #print(model_parameters) - #print("prompt_messages"+"*"*200) - #print(prompt_messages) - #print("tools"+"*"*20) - #print(tools) + # print("model"+"*"*20) + # print(model) + # print("credentials"+"*"*20) + # print(credentials) + # print("model_parameters"+"*"*20) + # print(model_parameters) + # print("prompt_messages"+"*"*200) + # print(prompt_messages) + # print("tools"+"*"*20) + # print(tools) # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -147,8 +156,13 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr return self._get_num_tokens_by_gpt2(prompt) - def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_characters( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -169,10 +183,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() @@ -192,11 +203,17 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -218,10 +235,12 @@ def _generate(self, model: str, credentials: dict, # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -230,12 +249,12 @@ def _generate(self, model: str, credentials: dict, else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") - #oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) compartment_id = oci_config["compartment_id"] client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) # call embedding model @@ -245,9 +264,9 @@ def _generate(self, model: str, credentials: dict, chat_history = [] system_prompts = [] - #if "meta.llama" in model: + # if "meta.llama" in model: # request_args["chatRequest"]["apiFormat"] = "GENERIC" - request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600) + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) request_args["chatRequest"].update(model_parameters) frequency_penalty = model_parameters.get("frequencyPenalty", 0) presence_penalty = model_parameters.get("presencePenalty", 0) @@ -267,7 +286,7 @@ def _generate(self, model: str, credentials: dict, if not valid_value: raise InvokeBadRequestError("Does not support function calling") if model.startswith("cohere"): - #print("run cohere " * 10) + # print("run cohere " * 10) for message in prompt_messages[:-1]: text = "" if isinstance(message.content, str): @@ -279,37 +298,37 @@ def _generate(self, model: str, credentials: dict, if isinstance(message, SystemPromptMessage): if isinstance(message.content, str): system_prompts.append(message.content) - args = {"apiFormat": "COHERE", - "preambleOverride": ' '.join(system_prompts), - "message": prompt_messages[-1].content, - "chatHistory": chat_history, } + args = { + "apiFormat": "COHERE", + "preambleOverride": " ".join(system_prompts), + "message": prompt_messages[-1].content, + "chatHistory": chat_history, + } request_args["chatRequest"].update(args) elif model.startswith("meta"): - #print("run meta " * 10) + # print("run meta " * 10) meta_messages = [] for message in prompt_messages: text = message.content meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]}) - args = {"apiFormat": "GENERIC", - "messages": meta_messages, - "numGenerations": 1, - "topK": -1} + args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1} request_args["chatRequest"].update(args) if stream: request_args["chatRequest"]["isStream"] = True - #print("final request" + "|" * 20) - #print(request_args) + # print("final request" + "|" * 20) + # print(request_args) response = client.chat(request_args) - #print(vars(response)) + # print(vars(response)) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -320,9 +339,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Bas :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.data.chat_response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text) # calculate num tokens prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) @@ -341,8 +358,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Bas return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -356,14 +374,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon events = response.data.events() for stream in events: chunk = json.loads(stream.data) - #print(chunk) - #chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} - - + # print(chunk) + # chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} - #for chunk in response: - #for part in chunk.parts: - #if part.function_call: + # for chunk in response: + # for part in chunk.parts: + # if part.function_call: # assistant_prompt_message.tool_calls = [ # AssistantPromptMessage.ToolCall( # id=part.function_call.name, @@ -376,9 +392,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon # ] if "finishReason" not in chunk: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if model.startswith("cohere"): if chunk["text"]: assistant_prompt_message.content += chunk["text"] @@ -389,10 +403,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: # calculate num tokens @@ -409,8 +420,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=index, message=assistant_prompt_message, finish_reason=str(chunk["finishReason"]), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -425,9 +436,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -457,5 +466,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/oci/oci.py b/api/core/model_runtime/model_providers/oci/oci.py index 11d67790a03510..e182d2d0439d77 100644 --- a/api/core/model_runtime/model_providers/oci/oci.py +++ b/api/core/model_runtime/model_providers/oci/oci.py @@ -8,7 +8,6 @@ class OCIGENAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,14 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `cohere.command-r-plus` model for validate, - model_instance.validate_credentials( - model='cohere.command-r-plus', - credentials=credentials - ) + model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex - - diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 5e0a85583e9589..df77db47d9fb53 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -21,29 +21,28 @@ request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.embed-english-light-v3.0", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"}, "truncate": "NONE", - "inputs": [""] + "inputs": [""], } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + class OCITextEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -62,14 +61,13 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: cutoff = int(len(text) * (np.floor(context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -80,26 +78,16 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=inputs[i: i + max_chunks] + model=model, credentials=credentials, texts=inputs[i : i + max_chunks] ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -125,6 +113,7 @@ def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> for text in texts: characters += len(text) return characters + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -135,11 +124,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -157,10 +142,12 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> # initialize client oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -169,7 +156,7 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") @@ -195,10 +182,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -209,7 +193,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -224,19 +208,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 42a588e3dd5df1..160eea0148702b 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -121,9 +121,7 @@ def get_num_tokens( text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -145,13 +143,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {ex.description}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {str(ex)}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def _generate( self, @@ -201,9 +195,7 @@ def _generate( if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") - data["messages"] = [ - self._convert_prompt_message_to_dict(m) for m in prompt_messages - ] + data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -216,14 +208,10 @@ def _generate( images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) + message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub( r"^data:image\/[a-zA-Z]+;base64,", "", @@ -235,24 +223,16 @@ def _generate( data["images"] = images # send a post request to validate the credentials - response = requests.post( - endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError( - f"API request failed with status code {response.status_code}: {response.text}" - ) + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") if stream: - return self._handle_generate_stream_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return self._handle_generate_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) def _handle_generate_response( self, @@ -292,9 +272,7 @@ def _handle_generate_response( completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response result = LLMResult( @@ -335,9 +313,7 @@ def create_final_llm_result_chunk( completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( model=model, @@ -394,15 +370,11 @@ def create_final_llm_result_chunk( completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2( - prompt_messages[0].content - ) + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=chunk_json["model"], @@ -439,17 +411,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) - image_data = re.sub( - r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -479,9 +445,7 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: return num_tokens - def get_customizable_model_schema( - self, model: str, credentials: dict - ) -> AIModelEntity: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. @@ -502,9 +466,7 @@ def get_customizable_model_schema( fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get("context_size", 4096) - ), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), }, parameter_rules=[ ParameterRule( @@ -568,9 +530,7 @@ def get_customizable_model_schema( en_US="Maximum number of tokens to predict when generating text. " "(Default: 128, -1 = infinite generation, -2 = fill context)" ), - default=( - 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 - ), + default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128), min=-2, max=int(credentials.get("max_tokens", 4096)), ), @@ -612,22 +572,23 @@ def get_customizable_model_schema( label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, help=I18nObject( - en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + en_US="Sets the size of the context window used to generate the next token. " "(Default: 2048)" ), default=2048, min=1, ), ParameterRule( - name='num_gpu', + name="num_gpu", label=I18nObject(en_US="GPU Layers"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to offload to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." - "As long as a model fits into one gpu it stays in one. " - "It does not set the number of GPU(s). "), + help=I18nObject( + en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). " + ), min=-1, - default=1 + default=1, ), ParameterRule( name="num_thread", @@ -688,8 +649,7 @@ def get_customizable_model_schema( label=I18nObject(en_US="Format"), type=ParameterType.STRING, help=I18nObject( - en_US="the format to return a response in." - " Currently the only accepted value is json." + en_US="the format to return a response in." " Currently the only accepted value is json." ), options=["json"], ), diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py index f8a17b98a0d677..115280193a5ed6 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.py +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -6,7 +6,6 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 8f7d54c5167255..60b85197be62d6 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -37,9 +37,9 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,15 +51,13 @@ def _invoke(self, model: str, credentials: dict, """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embed') + endpoint_url = urljoin(endpoint_url, "api/embed") # get model properties context_size = self._get_context_size(model, credentials) @@ -74,46 +72,34 @@ def _invoke(self, model: str, credentials: dict, if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) # Prepare the payload for the request payload = { - 'input': inputs, - 'model': model, + "input": inputs, + "model": model, } # Make the request to the OpenAI API response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300), - options={"use_mmap": "true"} + endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300), options={"use_mmap": "true"} ) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings = response_data['embeddings'] + embeddings = response_data["embeddings"] embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) used_tokens += embedding_used_tokens # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -135,19 +121,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -155,15 +137,15 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -179,10 +161,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -193,7 +172,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -221,10 +200,10 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 467a51daf2a278..2181bb4f08fd8f 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -22,7 +22,7 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } @@ -31,8 +31,8 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: openai_api_base = credentials["openai_api_base"].rstrip("/") credentials_kwargs["base_url"] = openai_api_base + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index dc85f7c9f2a33d..5950b77a96a015 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -39,16 +39,23 @@ """ + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,8 +71,8 @@ def _invoke(self, model: str, credentials: dict, """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -80,7 +87,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -91,26 +98,34 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -123,7 +138,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -135,9 +150,9 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -147,14 +162,21 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -167,25 +189,35 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -202,25 +234,30 @@ def _transform_completion_json_prompts(self, model: str, credentials: dict, break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -231,8 +268,8 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -262,14 +299,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError(f"Fine-tuned model {model} not found") # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -277,7 +314,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -286,7 +323,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -313,11 +350,11 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [model for model in remote_models if model.id.startswith("ft:")] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None for predefined_model_name, predefined_model in predefined_models_map.items(): @@ -329,30 +366,29 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, - pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + pricing=PriceConfig(input=0.003, output=0.006, unit=0.001, currency="USD"), ) ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -374,23 +410,17 @@ def _generate(self, model: str, credentials: dict, extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - "include_usage": True - } - + extra_model_kwargs["stream_options"] = {"include_usage": True} + # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -398,8 +428,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm completion response @@ -412,9 +443,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Com assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -440,8 +469,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Com return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm completion stream response @@ -451,7 +481,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" prompt_tokens = 0 completion_tokens = 0 @@ -460,8 +490,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -474,14 +504,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -494,7 +522,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -504,7 +532,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -520,10 +548,17 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon yield final_chunk - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -562,22 +597,18 @@ def _chat_generate(self, model: str, credentials: dict, if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - 'include_usage': True - } + extra_model_kwargs["stream_options"] = {"include_usage": True} # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) @@ -596,9 +627,14 @@ def _chat_generate(self, model: str, credentials: dict, return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -619,10 +655,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -648,9 +681,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -660,7 +698,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None prompt_tokens = 0 completion_tokens = 0 @@ -670,8 +708,8 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -685,8 +723,11 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -708,7 +749,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -720,11 +761,10 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -735,7 +775,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -745,7 +785,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -753,8 +793,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -764,9 +803,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -777,21 +816,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -801,14 +838,11 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -821,7 +855,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp :param prompt_messages: prompt messages :return: cleaned prompt messages """ - checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] if model in checklist: # count how many user messages are there @@ -830,11 +864,16 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - item.data if item.type == PromptMessageContentType.TEXT else - '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' - for item in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) return prompt_messages @@ -851,19 +890,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -889,11 +922,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -902,8 +931,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -924,13 +952,14 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. if model == "chatgpt-4o-latest": @@ -969,10 +998,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -1011,37 +1040,37 @@ def _num_tokens_for_tools(self, encoding: tiktoken.Encoding, tools: list[PromptM """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -1049,26 +1078,26 @@ def _num_tokens_for_tools(self, encoding: tiktoken.Encoding, tools: list[PromptM def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -1077,16 +1106,13 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad26920..619044d808cdf6 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -14,9 +14,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model. """ - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -34,10 +32,10 @@ def _invoke(self, model: str, credentials: dict, # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] for text_chunk in chunks: moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) @@ -65,7 +63,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index 66efd4797f621a..175d7db73c46d4 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -9,7 +9,6 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials @@ -22,12 +21,9 @@ def validate_provider_credentials(self, credentials: Mapping) -> None: # Use `gpt-3.5-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index efbdd054f9457c..18f97e45f33bd8 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -12,9 +12,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -37,7 +35,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index e23a2edf87aefb..535d8388bc7d7d 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -18,9 +18,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +37,9 @@ def _invoke(self, model: str, credentials: dict, extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +56,9 @@ def _invoke(self, model: str, credentials: dict, enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -69,10 +67,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,10 +83,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -101,17 +93,9 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -152,17 +136,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,10 +159,12 @@ def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens @@ -197,10 +179,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -211,7 +190,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index afa5d4b88adecb..bfb443698ca8d7 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -14,8 +14,9 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -28,14 +29,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) # if streaming: - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -50,14 +49,13 @@ def validate_credentials(self, model: str, credentials: dict, user: Optional[str self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -71,31 +69,38 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 51950ca3778424..257dffa30de829 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from core.model_runtime.errors.invoke import ( @@ -35,10 +34,10 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] - } \ No newline at end of file + requests.exceptions.ReadTimeout, # Timeout + ], + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 6279125f46ffa0..75929af5904133 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -46,11 +46,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,8 +83,13 @@ def _invoke(self, model: str, credentials: dict, user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -99,93 +110,85 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''): - json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''): - json_result['object'] = 'text_completion' + if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": + json_result["object"] = "chat.completion" + elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": + json_result["object"] = "text_completion" - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type in ['function_call']: + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type in ["function_call"]: features.append(ModelFeature.TOOL_CALL) - elif function_calling_type in ['tool_call']: + elif function_calling_type in ["tool_call"]: features.append(ModelFeature.MULTI_TOOL_CALL) - stream_function_calling = credentials.get('stream_function_calling', 'supported') - if stream_function_calling == 'supported': + stream_function_calling = credentials.get("stream_function_calling", "supported") + if stream_function_calling == "supported": features.append(ModelFeature.STREAM_TOOL_CALL) - vision_support = credentials.get('vision_support', 'not_support') - if vision_support == 'support': + vision_support = credentials.get("vision_support", "not_support") + if vision_support == "support": features.append(ModelFeature.VISION) entity = AIModelEntity( @@ -195,43 +198,43 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -239,20 +242,20 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") @@ -260,10 +263,17 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode return entity # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -277,52 +287,47 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - extra_headers = credentials.get('extra_headers') + extra_headers = credentials.get("extra_headers") if extra_headers is not None: headers = { - **headers, - **extra_headers, + **headers, + **extra_headers, } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -336,16 +341,10 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") @@ -355,8 +354,9 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -366,11 +366,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -381,16 +382,12 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") @@ -406,10 +403,7 @@ def get_tool_call(tool_call_id: str): tool_call = AssistantPromptMessage.ToolCall( id=tool_call_id, type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name="", - arguments="" - ) + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) tools_calls.append(tool_call) @@ -434,10 +428,10 @@ def get_tool_call(tool_call_id: str): chunk = chunk.strip() if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() - if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue try: @@ -447,30 +441,31 @@ def get_tool_call(tool_call_id: str): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") assistant_message_tool_calls = None - if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': - assistant_message_tool_calls = delta.get('tool_calls', None) - elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': - assistant_message_tool_calls = [{ - 'id': 'tool_call_id', - 'type': 'function', - 'function': delta.get('function_call', {}) - }] + if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": + assistant_message_tool_calls = delta.get("tool_calls", None) + elif ( + "function_call" in delta + and credentials.get("function_calling_type", "no_call") == "function_call" + ): + assistant_message_tool_calls = [ + {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} + ] # assistant_message_function_call = delta.delta.function_call @@ -479,7 +474,7 @@ def get_tool_call(tool_call_id: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message @@ -490,9 +485,9 @@ def get_tool_call(tool_call_id: str): # reset tool calls tool_calls = [] full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -507,7 +502,7 @@ def get_tool_call(tool_call_id: str): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -518,47 +513,42 @@ def get_tool_call(tool_call_id: str): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason ) - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> LLMResult: - + def _handle_generate_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> LLMResult: response_json = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) if tool_calls: - if function_calling_type == 'tool_call': + if function_calling_type == "tool_call": assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': + elif function_calling_type == "function_call": assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] usage = response_json.get("usage") @@ -597,19 +587,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -618,11 +602,10 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] - elif function_calling_type == 'function_call': + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] + elif function_calling_type == "function_call": function_call = message.tool_calls[0] message_dict["function_call"] = { "name": function_call.function.name, @@ -633,19 +616,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } - elif function_calling_type == 'function_call': - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif function_calling_type == "function_call": + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -654,8 +629,9 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O return message_dict - def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -667,7 +643,7 @@ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessag if isinstance(text, str): full_text = text else: - full_text = '' + full_text = "" for message_content in text: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) @@ -680,8 +656,13 @@ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessag return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: dict = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. """ @@ -700,10 +681,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -741,46 +722,44 @@ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -792,20 +771,17 @@ def _extract_response_tool_calls(self, for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.get("function", {}).get("name", ""), - arguments=response_tool_call.get("function", {}).get("arguments", "") + arguments=response_tool_call.get("function", {}).get("arguments", ""), ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), - type=response_tool_call.get("type", ""), - function=function + id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -815,14 +791,11 @@ def _extract_response_function_call(self, response_function_call) \ tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get('name', ''), - arguments=response_function_call.get('arguments', '') + name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get('id', ''), - type="function", - function=function + id=response_function_call.get("id", ""), type="function", function=function ) return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 3445ebbaf752b3..ca6f1852872fed 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -6,6 +6,5 @@ class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 00702ba9367cf4..2e8b4ddd7234c2 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -14,9 +14,7 @@ class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): Model class for OpenAI Compatible Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 363054b084a69c..ab358cf70af9c8 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -27,9 +27,9 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,27 +39,25 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - + # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +68,6 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +75,7 @@ def _invoke(self, model: str, credentials: dict, if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -88,42 +85,25 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -145,45 +125,35 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -191,7 +161,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -199,20 +169,19 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -224,10 +193,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -238,7 +204,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 8ea5819bde1167..b560afca39e9db 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -38,88 +38,115 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, - prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') - ], + server_url=credentials["server_url"], + model_name=model, + prompt_messages=[OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user")], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], + server_url=credentials["server_url"], prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in prompt_messages], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,25 +157,27 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -159,73 +188,55 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +252,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 1c3f084207ff79..e754479ec0103a 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,40 +78,37 @@ def generate( "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if "use_cache" in model_parameters and type(model_parameters["use_cache"]) == bool: + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop if stop else [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: @@ -113,10 +116,10 @@ def generate( except (ConnectionError, InvalidSchema, MissingSchema) as e: # cloud not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +128,71 @@ def generate( raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if choice.get('finish_reason'): - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - - yield message \ No newline at end of file + + yield message diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6ca0ed1..309b5cf413bd54 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 4dbd0678e71f7f..00e583cc797dcb 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,16 +36,13 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +52,7 @@ def _invoke(self, model: str, credentials: dict, raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,21 +60,17 @@ def _invoke(self, model: str, credentials: dict, raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -104,9 +98,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,23 +113,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -147,10 +131,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -161,7 +142,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index e78ac4caf1da1e..71b5745f7d1fba 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -8,18 +8,23 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_credential(self, model: str, credentials: dict): - credentials['endpoint_url'] = "https://openrouter.ai/api/v1" - credentials['mode'] = self.get_model_mode(model).value - credentials['function_calling_type'] = 'tool_call' + credentials["endpoint_url"] = "https://openrouter.ai/api/v1" + credentials["mode"] = self.get_model_mode(model).value + credentials["function_calling_type"] = "tool_call" return - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -29,9 +34,17 @@ def validate_credentials(self, model: str, credentials: dict) -> None: return super().validate_credentials(model, credentials) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,8 +54,13 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode return super().get_customizable_model_schema(model, credentials) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: self._update_credential(model, credentials) return super().get_num_tokens(model, credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py b/api/core/model_runtime/model_providers/openrouter/openrouter.py index 613f71deb1c806..2e59ab50598b8b 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.py +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -8,17 +8,13 @@ class OpenRouterProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='openai/gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="openai/gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex \ No newline at end of file + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py index c9116bf68538b4..89cac665aa5a08 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py @@ -13,11 +13,17 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -27,8 +33,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -46,8 +51,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -67,10 +73,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -101,10 +107,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://cloud.perfxlab.cn' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://cloud.perfxlab.cn" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py index 0854ef5185143d..450d22fb75943a 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py +++ b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py @@ -8,7 +8,6 @@ class PerfXCloudProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='Qwen2-72B-Instruct-GPTQ-Int4', - credentials=credentials - ) + model_instance.validate_credentials(model="Qwen2-72B-Instruct-GPTQ-Int4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 11d57e3749a8f1..d0522233e36f00 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -27,9 +27,9 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,30 +39,28 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - + # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -73,7 +71,6 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -81,7 +78,7 @@ def _invoke(self, model: str, credentials: dict, if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -91,42 +88,25 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -148,48 +128,38 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -197,7 +167,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -205,20 +175,19 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -230,10 +199,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -244,7 +210,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 29d8427d8ef231..915f6e0eefcd08 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -4,12 +4,6 @@ class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 31b81a829e0882..87c8bc4a91cda9 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -28,16 +28,22 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] + + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -48,39 +54,43 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if "system_prompt" in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"]: + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, " "format: {user_name}/{model_name}" + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -91,45 +101,44 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version, description): - if 'language model' in description.lower(): + if "language model" in description.lower(): return - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: + if ( + "temperature" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_p" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_k" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + ): raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.") def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules(model, credentials), ) return entity @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -140,15 +149,13 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) parameter_rules = [] input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in ["system_prompt", "prompt"] and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -157,28 +164,28 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -189,7 +196,7 @@ def _handle_generate_stream_response(self, for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -207,18 +214,13 @@ def _handle_generate_stream_response(self, index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=output if output else '' - ) + assistant_prompt_message = AssistantPromptMessage(content=output if output else "") if index < prediction_output_length: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -229,15 +231,17 @@ def _handle_generate_stream_response(self, yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -255,9 +259,7 @@ def _handle_generate_response(self, model: str, credentials: dict, prediction: P if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -275,21 +277,13 @@ def _handle_generate_response(self, model: str, credentials: dict, prediction: P @classmethod def _get_parameter_type(cls, param_type: str) -> str: - type_mapping = { - 'integer': 'int', - 'number': 'float', - 'boolean': 'boolean', - 'string': 'string' - } + type_mapping = {"integer": "int", "number": "float", "boolean": "boolean", "string": "string"} return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/replicate/replicate.py b/api/core/model_runtime/model_providers/replicate/replicate.py index 3a5c9b84a07b52..ca137579c96f2c 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.py +++ b/api/core/model_runtime/model_providers/replicate/replicate.py @@ -6,6 +6,5 @@ class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 0e4cdbf5bc13ca..f6b7754d74f4c1 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -13,32 +13,27 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -47,39 +42,35 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @@ -90,49 +81,45 @@ def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) # sort through the openapi schema to get the name of text, texts or inputs input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in ("text", "texts", "inputs"): text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] + ) -> list[list[float]]: + if text_input_key in ("text", "inputs"): embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run(replicate_model_version, input={text_input_key: text}) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -143,7 +130,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 3d4c5825af3887..2edd13d56d4d87 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -44,10 +44,11 @@ logger = logging.getLogger(__name__) -def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False): - """ + +def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False): + """ params: - predictor : Sagemaker Predictor + predictor : Sagemaker Predictor messages (List[Dict[str,Any]]): message list。 messages = [ {"role": "system", "content":"please answer in Chinese"}, @@ -55,19 +56,19 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto ] params (Dict[str,Any]): model parameters for LLM。 stream (bool): False by default。 - + response: result of inference if stream is False Iterator of Chunks if stream is True """ payload = { - "model" : params.get('model_name'), - "stop" : stop, + "model": params.get("model_name"), + "stop": stop, "messages": messages, - "stream" : stream, - "max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)), - "temperature" : params.get('temperature', 0.1), - "top_p" : params.get('top_p', 0.9), + "stream": stream, + "max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)), + "temperature": params.get("temperature", 0.1), + "top_p": params.get("top_p", 0.9), } if not stream: @@ -77,36 +78,41 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto response_stream = predictor.predict_stream(payload) return response_stream + class SageMakerLargeLanguageModel(LargeLanguageModel): """ Model class for Cohere large language model. """ - sagemaker_client: Any = None - sagemaker_sess : Any = None - predictor : Any = None - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: bytes) -> LLMResult: + sagemaker_client: Any = None + sagemaker_sess: Any = None + predictor: Any = None + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: bytes, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - resp_obj = json.loads(resp.decode('utf-8')) - resp_str = resp_obj.get('choices')[0].get('message').get('content') + resp_obj = json.loads(resp.decode("utf-8")) + resp_str = resp_obj.get("choices")[0].get("message").get("content") if len(resp_str) == 0: raise InvokeServerUnavailableError("Empty response") - assistant_prompt_message = AssistantPromptMessage( - content=resp_str, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -118,37 +124,43 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[bytes]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[bytes], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" buffer = "" for chunk_bytes in resp: - buffer += chunk_bytes.decode('utf-8') + buffer += chunk_bytes.decode("utf-8") last_idx = 0 - for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer): + for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer): try: data = json.loads(match.group(1).strip()) last_idx = match.span()[1] if "content" in data["choices"][0]["delta"]: chunk_content = data["choices"][0]["delta"]["content"] - assistant_prompt_message = AssistantPromptMessage( - content=chunk_content, - tool_calls=[] - ) - - if data["choices"][0]['finish_reason'] is not None: - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) + + if data["choices"][0]["finish_reason"] is not None: + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -157,8 +169,8 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes delta=LLMResultChunkDelta( index=0, message=assistant_prompt_message, - finish_reason=data["choices"][0]['finish_reason'], - usage=usage + finish_reason=data["choices"][0]["finish_reason"], + usage=usage, ), ) else: @@ -166,10 +178,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes model=model, prompt_messages=prompt_messages, system_fingerprint=None, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message - ), + delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), ) full_response += chunk_content @@ -179,11 +188,17 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes buffer = buffer[last_idx:] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -198,15 +213,17 @@ def _invoke(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ if not self.sagemaker_client: - access_key = credentials.get('access_key') - secret_key = credentials.get('secret_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("access_key") + secret_key = credentials.get("secret_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -214,25 +231,26 @@ def _invoke(self, model: str, credentials: dict, sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) self.predictor = Predictor( - endpoint_name=credentials.get('sagemaker_endpoint'), + endpoint_name=credentials.get("sagemaker_endpoint"), sagemaker_session=sagemaker_session, serializer=serializers.JSONSerializer(), ) - - messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ] - response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream) + messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages] + response = inference( + predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream + ) if stream: if tools and len(tools) > 0: raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode") - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ @@ -247,19 +265,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -269,7 +281,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -282,8 +294,9 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: return message_dict - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -299,10 +312,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -339,8 +352,13 @@ def tokens(text: str): return num_tokens - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -381,89 +399,63 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = LLMMode.value_of(credentials["mode"]).value features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 6b7cfc210bc2a2..7e7614055c76b5 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -20,34 +20,36 @@ logger = logging.getLogger(__name__) + class SageMakerRerankModel(RerankModel): """ Model class for SageMaker rerank model. """ + sagemaker_client: Any = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -63,22 +65,21 @@ def _invoke(self, model: str, credentials: dict, line = 0 try: if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -86,22 +87,20 @@ def _invoke(self, model: str, credentials: dict, line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) for idx in range(len(scores)): - candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + candidate_docs.append({"content": docs[idx], "score": scores[idx]}) - sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 3 rerank_documents = [] for idx, result in enumerate(candidate_docs): rerank_document = RerankDocument( - index=idx, - text=result.get('content'), - score=result.get('score', -100.0) + index=idx, text=result.get("content"), score=result.get("score", -100.0) ) if score_threshold is not None: @@ -110,13 +109,10 @@ def _invoke(self, model: str, credentials: dict, else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -137,7 +133,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -153,38 +149,24 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index 6f3e02489fb0e5..042155b1522fef 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class SageMakerProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -17,27 +18,24 @@ def validate_provider_credentials(self, credentials: dict) -> None: """ pass -def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str: - ''' - return s3_uri of this file - ''' - s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3' - s3_client.put_object( - Body=file.read(), - Bucket=bucket, - Key=s3_key, - ContentType='audio/mp3' - ) + +def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str: + """ + return s3_uri of this file + """ + s3_key = f"{s3_prefix}{uuid.uuid4()}.mp3" + s3_client.put_object(Body=file.read(), Bucket=bucket, Key=s3_key, ContentType="audio/mp3") return s3_key -def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str: + +def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str: object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) try: - response = s3_client.generate_presigned_url('get_object', - Params={'Bucket': bucket_name, 'Key': object_key}, - ExpiresIn=expiration) + response = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket_name, "Key": object_key}, ExpiresIn=expiration + ) except Exception as e: print(f"Error generating presigned URL: {e}") return None - return response \ No newline at end of file + return response diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 8b57f182fe7771..6aa8c9995f68b4 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -19,16 +19,16 @@ logger = logging.getLogger(__name__) + class SageMakerSpeech2TextModel(Speech2TextModel): """ Model class for Xinference speech to text model. """ + sagemaker_client: Any = None - s3_client : Any = None + s3_client: Any = None - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -42,19 +42,20 @@ def _invoke(self, model: str, credentials: dict, try: if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) @@ -62,25 +63,21 @@ def _invoke(self, model: str, credentials: dict, self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - s3_prefix='dify/speech2text/' - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - bucket = credentials.get('audio_s3_cache_bucket') + s3_prefix = "dify/speech2text/" + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + bucket = credentials.get("audio_s3_cache_bucket") s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) - payload = { - "audio_s3_presign_uri" : s3_presign_url - } + payload = {"audio_s3_presign_uri": s3_presign_url} response_model = self.sagemaker_client.invoke_endpoint( - EndpointName=sagemaker_endpoint, - Body=json.dumps(payload), - ContentType="application/json" + EndpointName=sagemaker_endpoint, Body=json.dumps(payload), ContentType="application/json" ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - asr_text = json_obj['text'] + asr_text = json_obj["text"] except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") return asr_text @@ -105,38 +102,24 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 4b2858b1a28228..d55144f8a79be9 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -10,21 +10,22 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel BATCH_SIZE = 20 -CONTEXT_SIZE=8192 +CONTEXT_SIZE = 8192 logger = logging.getLogger(__name__) + def batch_generator(generator, batch_size): while True: batch = list(itertools.islice(generator, batch_size)) @@ -32,33 +33,28 @@ def batch_generator(generator, batch_size): break yield batch + class SageMakerEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ + sagemaker_client: Any = None - def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]): response_model = sm_client.invoke_endpoint( EndpointName=endpoint_name, - Body=json.dumps( - { - "inputs": content_list, - "parameters": {}, - "is_query" : False, - "instruction" : '' - } - ), + Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - embeddings = json_obj['embeddings'] + embeddings = json_obj["embeddings"] return embeddings - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -72,25 +68,27 @@ def _invoke(self, model: str, credentials: dict, try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") line = 3 - truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + truncated_texts = [item[:CONTEXT_SIZE] for item in texts] batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) all_embeddings = [] @@ -105,18 +103,14 @@ def _invoke(self, model: str, credentials: dict, usage = self._calc_response_usage( model=model, credentials=credentials, - tokens=0 # It's not SAAS API, usage is meaningless + tokens=0, # It's not SAAS API, usage is meaningless ) line = 6 - return TextEmbeddingResult( - embeddings=all_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,10 +147,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -167,7 +158,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -175,40 +166,28 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 315b31fd858397..3dd5f8f64cfc25 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -22,89 +22,93 @@ logger = logging.getLogger(__name__) + class TTSModelType(Enum): PresetVoice = "PresetVoice" CloneVoice = "CloneVoice" CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" -class SageMakerText2SpeechModel(TTSModel): +class SageMakerText2SpeechModel(TTSModel): sagemaker_client: Any = None - s3_client : Any = None - comprehend_client : Any = None + s3_client: Any = None + comprehend_client: Any = None def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ pass - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] - - return map_dict.get(language_code, '<|zh|>') - - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + language_code = response["Languages"][0]["LanguageCode"] + + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -117,61 +121,55 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s :return: text translated to audio file """ if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + self.comprehend_client = boto3.client( + "comprehend", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') - - model_type = credentials.get('audio_model_type', 'PresetVoice') - prompt_text = credentials.get('prompt_text') - prompt_audio = credentials.get('prompt_audio') - instruct_text = credentials.get('instruct_text') - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - payload = self._build_tts_payload( - model_type, - content_text, - voice, - prompt_text, - prompt_audio, - instruct_text - ) + self.comprehend_client = boto3.client("comprehend") + + model_type = credentials.get("audio_model_type", "PresetVoice") + prompt_text = credentials.get("prompt_text") + prompt_audio = credentials.get("prompt_audio") + instruct_text = credentials.get("instruct_text") + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + payload = self._build_tts_payload(model_type, content_text, voice, prompt_text, prompt_audio, instruct_text) return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -187,23 +185,11 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _get_model_default_voice(self, model: str, credentials: dict) -> any: @@ -219,27 +205,27 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = 'CosyVoice' + audio_model_name = "CosyVoice" for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any: + def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> any: """ _tts_invoke_streaming text2speech model @@ -250,38 +236,40 @@ def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint :return: text translated to audio file """ try: - lang_tag = '' + lang_tag = "" if model_type == TTSModelType.CloneVoice_CrossLingual.value: - lang_tag = payload.pop('lang_tag') - - word_limit = self._get_model_word_limit(model='', credentials={}) + lang_tag = payload.pop("lang_tag") + + word_limit = self._get_model_word_limit(model="", credentials={}) content_text = payload.get("tts_text") if len(content_text) > word_limit: split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) - sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ] + sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)] len_sent = len(sentences) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) - payloads = [ copy.deepcopy(payload) for i in range(len_sent) ] + payloads = [copy.deepcopy(payload) for i in range(len_sent)] for idx in range(len_sent): payloads[idx]["tts_text"] = sentences[idx] - futures = [ executor.submit( - self._invoke_sagemaker, - payload=payload, - endpoint=sagemaker_endpoint, - ) - for payload in payloads] + futures = [ + executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads + ] for index, future in enumerate(futures): resp = future.result() - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] else: resp = self._invoke_sagemaker(payload, sagemaker_endpoint) - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index a9ce7b98c35a39..c1868b6ad02b83 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -7,11 +7,17 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 683591581638e4..6f652e9d524e9c 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -16,39 +16,39 @@ class SiliconflowRerankModel(RerankModel): - - def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n, - "return_documents": True - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -57,7 +57,6 @@ def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], def validate_credentials(self, model: str, credentials: dict) -> None: try: - self._invoke( model=model, credentials=credentials, @@ -68,7 +67,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,5 +82,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [httpx.RemoteProtocolError], InvokeRateLimitError: [], InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] - } \ No newline at end of file + InvokeBadRequestError: [httpx.RequestError], + } diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index dd0eea362a5f83..e121ab8c7e4e2f 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -8,7 +8,6 @@ class SiliconflowProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py index 6ad3cab5873c69..8d1932863e09d9 100644 --- a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -8,9 +8,7 @@ class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): Model class for Siliconflow Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c58765cecb9a69..6cdf4933b47c73 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -10,20 +10,21 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): """ Model class for Siliconflow text embedding model. """ + def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, texts, user) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: self._add_custom_parameters(credentials) return super().get_num_tokens(model, credentials, texts) - + @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index d57766a87a0b07..25223e83405109 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -15,54 +15,35 @@ class SparkLLMClient: def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain model_api_configs = { - 'spark-lite': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-pro': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-pro-128k': { - 'version': 'pro-128k', - 'chat_domain': 'pro-128k' - }, - 'spark-max': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - }, - 'spark-4.0-ultra': { - 'version': 'v4.0', - 'chat_domain': '4.0Ultra' - } + "spark-lite": {"version": "v1.1", "chat_domain": "general"}, + "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, + "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, + "spark-4.0-ultra": {"version": "v4.0", "chat_domain": "4.0Ultra"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] - self.chat_domain = model_api_configs[model]['chat_domain'] + self.chat_domain = model_api_configs[model]["chat_domain"] - if model == 'spark-pro-128k': + if model == "spark-pro-128k": self.api_base = f"wss://{domain}/{endpoint}/{api_version}" else: self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret + urlparse(self.api_base).netloc, urlparse(self.api_base).path, self.api_base, api_key, api_secret ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: # generate timestamp by RFC1123 @@ -74,33 +55,29 @@ def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secr signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" + urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None, streaming: bool = False): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -109,86 +86,71 @@ def run(self, messages: list, user_id: str, ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put({"status_code": error.status_code, "error": error.resp_body.decode("utf-8")}) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps(self.gen_params(messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs)) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put({"status_code": 400, "error": f"Code: {code}, Error: {data['header']['message']}"}) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None) -> dict: data = { "header": { "app_id": self.app_id, # resolve this error message => $.header.uid' length must be less or equal than 32 - "uid": user_id[:32] if user_id else None - }, - "parameter": { - "chat": { - "domain": self.chat_domain - } + "uid": user_id[:32] if user_id else None, }, - "payload": { - "message": { - "text": messages - } - } + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if model_kwargs: - data['parameter']['chat'].update(model_kwargs) + data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions." + ) else: raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - if 'data' not in content: + if "data" not in content: break yield content diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 65beae517c72e9..0c42acf5aa6f87 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -25,12 +25,17 @@ class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -47,8 +52,13 @@ def _invoke(self, model: str, credentials: dict, # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +90,21 @@ def validate_credentials(self, model: str, credentials: dict) -> None: model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +119,7 @@ def _generate(self, model: str, credentials: dict, """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +129,33 @@ def _generate(self, model: str, credentials: dict, **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + {"role": prompt_message.role.value, "content": prompt_message.content} + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +168,7 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,9 +176,7 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -168,9 +194,15 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -183,12 +215,12 @@ def _handle_generate_stream_response(self, thread: threading.Thread, model: str, """ for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta if delta else "", ) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -199,11 +231,7 @@ def _handle_generate_stream_response(self, thread: threading.Thread, model: str, yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) thread.join() @@ -216,9 +244,9 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +272,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -254,10 +282,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -277,5 +302,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 6f6ffc8faa9be3..dab666e4d092c6 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -30,11 +30,17 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,51 +55,51 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 1024)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.stepfun.com/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" - def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ Convert PromptMessage to dict for OpenAI API format """ @@ -106,10 +112,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -117,7 +120,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op "type": "image_url", "image_url": { "url": message_content.data, - } + }, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -127,14 +130,16 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -160,21 +165,26 @@ def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[ if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -184,11 +194,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -199,12 +210,7 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -218,9 +224,9 @@ def get_tool_call(tool_name: str): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -242,9 +248,9 @@ def get_tool_call(tool_name: str): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -253,21 +259,21 @@ def get_tool_call(tool_name: str): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -275,19 +281,18 @@ def get_tool_call(tool_name: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -303,26 +308,21 @@ def get_tool_call(tool_name: str): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py index 50b17392b54276..e1c41a91537cd1 100644 --- a/api/core/model_runtime/model_providers/stepfun/stepfun.py +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -8,7 +8,6 @@ class StepfunProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='step-1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="step-1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py index b62b9860cb3486..9fd4a45f453489 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py @@ -69,8 +69,8 @@ class FlashRecognizer: """ response: request_id string - status Integer - message String + status Integer + message String audio_duration Integer flash_result Result Array @@ -81,16 +81,16 @@ class FlashRecognizer: Sentence: text String - start_time Integer - end_time Integer - speaker_id Integer + start_time Integer + end_time Integer + speaker_id Integer word_list Word Array Word: - word String - start_time Integer - end_time Integer - stable_flag: Integer + word String + start_time Integer + end_time Integer + stable_flag: Integer """ def __init__(self, appid, credential): @@ -100,13 +100,13 @@ def __init__(self, appid, credential): def _format_sign_string(self, param): signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/" for t in param: - if 'appid' in t: + if "appid" in t: signstr += str(t[1]) break signstr += "?" for x in param: tmp = x - if 'appid' in x: + if "appid" in x: continue for t in tmp: signstr += str(t) @@ -121,10 +121,9 @@ def _build_header(self): return header def _sign(self, signstr, secret_key): - hmacstr = hmac.new(secret_key.encode('utf-8'), - signstr.encode('utf-8'), hashlib.sha1).digest() + hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest() s = base64.b64encode(hmacstr) - s = s.decode('utf-8') + s = s.decode("utf-8") return s def _build_req_with_signature(self, secret_key, params, header): @@ -138,14 +137,22 @@ def _build_req_with_signature(self, secret_key, params, header): def _create_query_arr(self, req): return { - 'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())), - 'engine_type': req.engine_type, 'voice_format': req.voice_format, - 'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id, - 'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty, - 'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc, - 'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info, - 'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword, - 'sentence_max_length': req.sentence_max_length + "appid": self.appid, + "secretid": self.credential.secret_id, + "timestamp": str(int(time.time())), + "engine_type": req.engine_type, + "voice_format": req.voice_format, + "speaker_diarization": req.speaker_diarization, + "hotword_id": req.hotword_id, + "customization_id": req.customization_id, + "filter_dirty": req.filter_dirty, + "filter_modal": req.filter_modal, + "filter_punc": req.filter_punc, + "convert_num_mode": req.convert_num_mode, + "word_info": req.word_info, + "first_channel_only": req.first_channel_only, + "reinforce_hotword": req.reinforce_hotword, + "sentence_max_length": req.sentence_max_length, } def recognize(self, req, data): diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py index 00ec5aa9c8202e..5b427663ca85b0 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py @@ -18,9 +18,7 @@ class TencentSpeech2TextModel(Speech2TextModel): - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -43,7 +41,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,10 +81,6 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - requests.exceptions.ConnectionError - ], - InvokeAuthorizationError: [ - CredentialsValidateFailedError - ] + InvokeConnectionError: [requests.exceptions.ConnectionError], + InvokeAuthorizationError: [CredentialsValidateFailedError], } diff --git a/api/core/model_runtime/model_providers/tencent/tencent.py b/api/core/model_runtime/model_providers/tencent/tencent.py index dd9f90bb474f4e..79c6f577b8d5ef 100644 --- a/api/core/model_runtime/model_providers/tencent/tencent.py +++ b/api/core/model_runtime/model_providers/tencent/tencent.py @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: """ try: model_instance = self.get_model_instance(ModelType.SPEECH2TEXT) - model_instance.validate_credentials( - model='tencent', - credentials=credentials - ) + model_instance.validate_credentials(model="tencent", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index bb802d407157bf..b96d43979ef54a 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -22,16 +22,21 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,12 +46,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -61,45 +76,45 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), - ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get("context_size", "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('temperature', 0.7)), + default=float(cred_with_endpoint.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('top_p', 1)), + default=float(cred_with_endpoint.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=TOP_K, label=I18nObject(en_US="Top K"), type=ParameterType.INT, - default=int(cred_with_endpoint.get('top_k', 50)), + default=int(cred_with_endpoint.get("top_k", 50)), min=-2147483647, max=2147483647, - precision=0 + precision=0, ), ParameterRule( name=REPETITION_PENALTY, label=I18nObject(en_US="Repetition Penalty"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('repetition_penalty', 1)), + default=float(cred_with_endpoint.get("repetition_penalty", 1)), min=-3.4, max=3.4, - precision=1 + precision=1, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -107,46 +122,49 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode type=ParameterType.INT, default=512, min=1, - max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + max=int(cred_with_endpoint.get("max_tokens_to_sample", 4096)), ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ], pricing=PriceConfig( - input=Decimal(cred_with_endpoint.get('input_price', 0)), - output=Decimal(cred_with_endpoint.get('output_price', 0)), - unit=Decimal(cred_with_endpoint.get('unit', 0)), - currency=cred_with_endpoint.get('currency', "USD") + input=Decimal(cred_with_endpoint.get("input_price", 0)), + output=Decimal(cred_with_endpoint.get("output_price", 0)), + unit=Decimal(cred_with_endpoint.get("unit", 0)), + currency=cred_with_endpoint.get("currency", "USD"), ), ) - if cred_with_endpoint['mode'] == 'chat': + if cred_with_endpoint["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif cred_with_endpoint['mode'] == 'completion': + elif cred_with_endpoint["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}") return entity - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index ffce4794e7a8ad..aa4100a7c9b4d8 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -6,6 +6,5 @@ class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index fab18b41fd0487..8a50c7aa05f38c 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -21,7 +21,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -51,5 +51,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 6667d40440240f..72c319d395657f 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -45,11 +45,17 @@ class TongyiLargeLanguageModel(LargeLanguageModel): tokenizers = {} - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -65,8 +71,14 @@ def _invoke(self, model: str, credentials: dict, """ # invoke model without code wrapper return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -76,10 +88,10 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :param tools: tools for tool calling :return: """ - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') - if model == 'farui-plus': - model = 'qwen-farui-plus' + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") + if model == "farui-plus": + model = "qwen-farui-plus" if model in self.tokenizers: tokenizer = self.tokenizers[model] @@ -110,16 +122,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -138,18 +156,18 @@ def _generate(self, model: str, credentials: dict, mode = self.get_model_mode(model, credentials) - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = self._convert_tools(tools) + extra_model_kwargs["tools"] = self._convert_tools(tools) if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -157,23 +175,22 @@ def _generate(self, model: str, credentials: dict, model_schema = self.get_model_schema(model, credentials) if ModelFeature.VISION in (model_schema.features or []): - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) response = MultiModalConversation.call(**params, stream=stream) else: # nothing different between chat model and completion model in tongyi - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) - response = Generation.call(**params, - result_format='message', - stream=stream) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = Generation.call(**params, result_format="message", stream=stream) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -184,9 +201,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen :return: llm response """ if response.status_code != 200 and response.status_code != HTTPStatus.OK: - raise ServiceUnavailableError( - response.message - ) + raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=response.output.choices[0].message.content, @@ -205,9 +220,13 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen return result - def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: Generator[GenerationResponse, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator[GenerationResponse, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -217,7 +236,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" tool_calls = [] for index, response in enumerate(responses): if response.status_code != 200 and response.status_code != HTTPStatus.OK: @@ -228,22 +247,22 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, resp_finish_reason = response.output.choices[0].finish_reason - if resp_finish_reason is not None and resp_finish_reason != 'null': + if resp_finish_reason is not None and resp_finish_reason != "null": resp_content = response.output.choices[0].message.content assistant_prompt_message = AssistantPromptMessage( - content='', + content="", ) - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] elif resp_content: # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message - assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + assistant_prompt_message.content = resp_content.replace(full_text, "", 1) full_text = resp_content @@ -251,12 +270,11 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, message_tool_calls = [] for tool_call_obj in tool_calls: message_tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_obj['function']['name'], - type='function', + id=tool_call_obj["function"]["name"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call_obj['function']['name'], - arguments=tool_call_obj['function']['arguments'] - ) + name=tool_call_obj["function"]["name"], arguments=tool_call_obj["function"]["arguments"] + ), ) message_tool_calls.append(message_tool_call) @@ -270,26 +288,23 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=resp_finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, usage=usage + ), ) else: resp_content = response.output.choices[0].message.content if not resp_content: - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] continue # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content.replace(full_text, '', 1), + content=resp_content.replace(full_text, "", 1), ) full_text = resp_content @@ -297,10 +312,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -311,7 +323,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -356,16 +368,14 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], - rich_content: bool = False) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage], rich_content: bool = False + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -375,24 +385,28 @@ def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[Prom tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content if not rich_content else [{"text": prompt_message.content}], + } + ) elif isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, str): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content + if not rich_content + else [{"text": prompt_message.content}], + } + ) else: sub_messages = [] for message_content in prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -402,35 +416,25 @@ def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[Prom # convert image base64 data to file in /tmp image_url = self._save_base64_image_to_file(message_content.data) - sub_message_dict = { - "image": image_url - } + sub_message_dict = {"image": image_url} sub_messages.append(sub_message_dict) # resort sub_messages to ensure text is always at last - sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + sub_messages = sorted(sub_messages, key=lambda x: "text" in x) - tongyi_messages.append({ - 'role': 'user', - 'content': sub_messages - }) + tongyi_messages.append({"role": "user", "content": sub_messages}) elif isinstance(prompt_message, AssistantPromptMessage): content = prompt_message.content if not content: - content = ' ' - message = { - 'role': 'assistant', - 'content': content if not rich_content else [{"text": content}] - } + content = " " + message = {"role": "assistant", "content": content if not rich_content else [{"text": content}]} if prompt_message.tool_calls: - message['tool_calls'] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] + message["tool_calls"] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] tongyi_messages.append(message) elif isinstance(prompt_message, ToolPromptMessage): - tongyi_messages.append({ - "role": "tool", - "content": prompt_message.content, - "name": prompt_message.tool_call_id - }) + tongyi_messages.append( + {"role": "tool", "content": prompt_message.content, "name": prompt_message.tool_call_id} + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -445,7 +449,7 @@ def _save_base64_image_to_file(self, base64_image: str) -> str: :return: image file path """ # get mime type and encoded string - mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + mime_type, encoded_string = base64_image.split(",")[0].split(";")[0].split(":")[1], base64_image.split(",")[1] # save image to file temp_dir = tempfile.gettempdir() @@ -463,19 +467,18 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[dict]: """ tool_definitions = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] properties_definitions = {} for p_key, p_val in properties.items(): - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" properties_definitions[p_key] = { - 'description': desc, - 'type': p_val['type'], + "description": desc, + "type": p_val["type"], } tool_definition = { @@ -484,8 +487,8 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[dict]: "name": tool.name, "description": tool.description, "parameters": properties_definitions, - "required": required_properties - } + "required": required_properties, + }, } tool_definitions.append(tool_definition) @@ -517,5 +520,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 97dcb72f7c1b3e..5783d2e383e2de 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -46,7 +46,6 @@ def _invoke( used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -71,12 +70,8 @@ def _invoke( batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=used_tokens - ) - return TextEmbeddingResult( - embeddings=batched_embeddings, usage=usage, model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -108,16 +103,12 @@ def validate_credentials(self, model: str, credentials: dict) -> None: credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents( - credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] - ) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents( - credentials_kwargs: dict, model: str, texts: list[str] - ) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -145,7 +136,7 @@ def embed_documents( raise ValueError("Embedding data is missing in the response.") else: raise ValueError("Response output is missing or does not contain embeddings.") - + if response.usage and "total_tokens" in response.usage: embedding_used_tokens += response.usage["total_tokens"] else: @@ -153,9 +144,7 @@ def embed_documents( return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def _calc_response_usage( - self, model: str, credentials: dict, tokens: int - ) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.py b/api/core/model_runtime/model_providers/tongyi/tongyi.py index d5e25e6ecf87ae..a084512de9a885 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `qwen-turbo` model for validate, - model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="qwen-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 664b02cd92fc09..48a38897a87e9e 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -18,8 +18,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -31,14 +32,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -53,14 +52,13 @@ def validate_credentials(self, model: str, credentials: dict, user: Optional[str self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -82,15 +80,21 @@ def invoke_remote(content, v, api_key, cb, at, wl): else: sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) for sentence in sentences: - SpeechSynthesizer.call(model=v, sample_rate=16000, - api_key=api_key, - text=sentence.strip(), - callback=cb, - format=at, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) - - threading.Thread(target=invoke_remote, args=( - content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + SpeechSynthesizer.call( + model=v, + sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) + + threading.Thread( + target=invoke_remote, + args=(content_text, voice, credentials.get("dashscope_api_key"), callback, audio_type, word_limit), + ).start() while True: audio = audio_queue.get() @@ -112,16 +116,18 @@ def _process_sentence(sentence: str, credentials: dict, voice: str, audio_type: :param audio_type: audio file type :return: text translated to audio file """ - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + api_key=credentials.get("dashscope_api_key"), + text=sentence.strip(), + format=audio_type, + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() class Callback(ResultCallback): - def __init__(self, queue: Queue): self._queue = queue diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py index 95272a41c2e1a8..cf7e3f14be03ad 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -33,198 +33,223 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={}, stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={}, + stream=False, + ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during connection: {str(ex)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause TritonInference LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause TritonInference LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): - text += f'User: {item.content}' + text += f"User: {item.content}" elif isinstance(item, SystemPromptMessage): - text += f'System: {item.content}' + text += f"System: {item.content}" elif isinstance(item, AssistantPromptMessage): - text += f'Assistant: {item.content}' + text += f"Assistant: {item.content}" else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=int(credentials.get('context_length', 2048)), - default=min(512, int(credentials.get('context_length', 2048))), - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + max=int(credentials.get("context_length", 2048)), + default=min(512, int(credentials.get("context_length", 2048))), + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), parameter_rules=rules, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_length", 2048)), }, ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if 'stream' in credentials and not bool(credentials['stream']) and stream: - raise ValueError(f'stream is not supported by model {model}') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + + if "stream" in credentials and not bool(credentials["stream"]) and stream: + raise ValueError(f"stream is not supported by model {model}") try: parameters = {} - if 'temperature' in model_parameters: - parameters['temperature'] = model_parameters['temperature'] - if 'top_p' in model_parameters: - parameters['top_p'] = model_parameters['top_p'] - if 'top_k' in model_parameters: - parameters['top_k'] = model_parameters['top_k'] - if 'presence_penalty' in model_parameters: - parameters['presence_penalty'] = model_parameters['presence_penalty'] - if 'frequency_penalty' in model_parameters: - parameters['frequency_penalty'] = model_parameters['frequency_penalty'] + if "temperature" in model_parameters: + parameters["temperature"] = model_parameters["temperature"] + if "top_p" in model_parameters: + parameters["top_p"] = model_parameters["top_p"] + if "top_k" in model_parameters: + parameters["top_k"] = model_parameters["top_k"] + if "presence_penalty" in model_parameters: + parameters["presence_penalty"] = model_parameters["presence_penalty"] + if "frequency_penalty" in model_parameters: + parameters["frequency_penalty"] = model_parameters["frequency_penalty"] - response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={ - 'text_input': self._convert_prompt_message_to_text(prompt_messages), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'parameters': { - 'stream': False, - **parameters + response = post( + str(URL(credentials["server_url"]) / "v2" / "models" / model / "generate"), + json={ + "text_input": self._convert_prompt_message_to_text(prompt_messages), + "max_tokens": model_parameters.get("max_tokens", 512), + "parameters": {"stream": False, **parameters}, }, - }, timeout=(10, 120)) + timeout=(10, 120), + ) response.raise_for_status() if response.status_code != 200: - raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}') - + raise InvokeBadRequestError(f"Invoke failed with status code {response.status_code}, {response.text}") + if stream: - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) except Exception as ex: - raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}') - - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> LLMResult: + raise InvokeConnectionError(f"An error occurred during connection: {str(ex)}") + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) usage.completion_tokens = self._get_num_tokens_by_gpt2(text) return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text - ), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> Generator: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -233,13 +258,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=text - ), - usage=usage - ) + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text), usage=usage), ) @property @@ -253,15 +272,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - ], - InvokeRateLimitError: [ - ], - InvokeAuthorizationError: [ - ], - InvokeBadRequestError: [ - ValueError - ] - } \ No newline at end of file + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [ValueError], + } diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py index 06846825ab6e35..d85f7c82e7db71 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py @@ -4,6 +4,7 @@ logger = logging.getLogger(__name__) + class XinferenceAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 13b73181e95ffb..47ebaccd84ab8a 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping import openai @@ -20,13 +19,13 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: Transform credentials to kwargs for model instance :param credentials: - :return: + :return: """ credentials_kwargs = { - "api_key": credentials['upstage_api_key'], + "api_key": credentials["upstage_api_key"], "base_url": "https://api.upstage.ai/v1/solar", "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0), - "max_retries": 1 + "max_retries": 1, } return credentials_kwargs @@ -53,5 +52,3 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] openai.APIError, ], } - - diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index d1ed4619d6bbbf..1014b53f39ef62 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -36,15 +36,23 @@ """ + class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ - Model class for Upstage large language model. + Model class for Upstage large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -67,15 +75,25 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, - model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] self._transform_chat_json_prompts( model=model, @@ -86,9 +104,9 @@ def _code_block_mode_wrapper(self, stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -98,15 +116,23 @@ def _code_block_mode_wrapper(self, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ - Transform json prompts + Transform json prompts """ if stop is None: stop = [] @@ -117,20 +143,29 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): prompt_messages[0] = SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=UPSTAGE_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: - prompt_messages.insert(0, SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=UPSTAGE_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -155,30 +190,31 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = OpenAI(**credentials_kwargs) client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=10, - stream=False + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) extra_model_kwargs = {} if tools: - extra_model_kwargs["functions"] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: extra_model_kwargs["stop"] = stop @@ -198,10 +234,15 @@ def _chat_generate(self, model: str, credentials: dict, if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -222,10 +263,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -251,9 +289,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -263,7 +306,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None prompt_tokens = 0 completion_tokens = 0 @@ -273,8 +316,8 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -288,8 +331,11 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -311,7 +357,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -323,11 +369,10 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -338,7 +383,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -348,7 +393,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -356,8 +401,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -367,9 +411,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -380,21 +424,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -404,14 +446,11 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -429,19 +468,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -467,11 +500,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -483,16 +512,17 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for solar with Huggingface Solar tokenizer. - Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer + Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer """ tokenizer = self._get_tokenizer() - tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> - tokens_prefix = 1 # <|startoftext|> - tokens_suffix = 3 # <|im_start|>assistant\n + tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> + tokens_prefix = 1 # <|startoftext|> + tokens_suffix = 3 # <|im_start|>assistant\n num_tokens = 0 num_tokens += tokens_prefix @@ -502,10 +532,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "tool_calls": @@ -538,37 +568,37 @@ def _num_tokens_for_tools(self, tokenizer: Tokenizer, tools: list[PromptMessageT """ num_tokens = 0 for tool in tools: - num_tokens += len(tokenizer.encode('type')) - num_tokens += len(tokenizer.encode('function')) + num_tokens += len(tokenizer.encode("type")) + num_tokens += len(tokenizer.encode("function")) # calculate num tokens for function object - num_tokens += len(tokenizer.encode('name')) + num_tokens += len(tokenizer.encode("name")) num_tokens += len(tokenizer.encode(tool.name)) - num_tokens += len(tokenizer.encode('description')) + num_tokens += len(tokenizer.encode("description")) num_tokens += len(tokenizer.encode(tool.description)) parameters = tool.parameters - num_tokens += len(tokenizer.encode('parameters')) - if 'title' in parameters: - num_tokens += len(tokenizer.encode('title')) + num_tokens += len(tokenizer.encode("parameters")) + if "title" in parameters: + num_tokens += len(tokenizer.encode("title")) num_tokens += len(tokenizer.encode(parameters.get("title"))) - num_tokens += len(tokenizer.encode('type')) + num_tokens += len(tokenizer.encode("type")) num_tokens += len(tokenizer.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(tokenizer.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(tokenizer.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(tokenizer.encode(key)) for field_key, field_value in value.items(): num_tokens += len(tokenizer.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(tokenizer.encode(enum_field)) else: num_tokens += len(tokenizer.encode(field_key)) num_tokens += len(tokenizer.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(tokenizer.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(tokenizer.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(tokenizer.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 05ae8665d65bdd..edd4a36d98f587 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -18,6 +18,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): """ Model class for Upstage text embedding model. """ + def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") @@ -53,9 +54,9 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N for i, text in enumerate(texts): token = tokenizer.encode(text, add_special_tokens=False).tokens for j in range(0, len(token), context_size): - tokens += [token[j:j+context_size]] + tokens += [token[j : j + context_size]] indices += [i] - + batched_embeddings = [] _iter = range(0, len(tokens), max_chunks) @@ -63,20 +64,20 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i:i+max_chunks], + texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch - + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) - + for i in range(len(texts)): _result = results[i] if len(_result) == 0: @@ -91,15 +92,11 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() - - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: tokenizer = self._get_tokenizer() """ @@ -122,7 +119,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int total_num_tokens += len(tokenized_text) return total_num_tokens - + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -137,16 +134,13 @@ def validate_credentials(self, model: str, credentials: Mapping) -> None: client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model :param model: model name @@ -155,17 +149,19 @@ def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], :param extra_model_kwargs: extra model kwargs :return: embeddings and used tokens """ - response = client.embeddings.create( - model=model, - input=texts, - **extra_model_kwargs - ) + response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs) + + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": + return ( + [ + list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) + for embedding in response.data + ], + response.usage.total_tokens, + ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': - return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens) - return [data.embedding for data in response.data], response.usage.total_tokens - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -176,10 +172,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em :return: usage """ input_price_info = self.get_price( - model=model, - credentials=credentials, - tokens=tokens, - price_type=PriceType.INPUT + model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT ) usage = EmbeddingUsage( @@ -189,7 +182,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/upstage.py b/api/core/model_runtime/model_providers/upstage/upstage.py index 56c91c00618922..e45d4aae19eb6c 100644 --- a/api/core/model_runtime/model_providers/upstage/upstage.py +++ b/api/core/model_runtime/model_providers/upstage/upstage.py @@ -8,7 +8,6 @@ class UpstageProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,14 +18,10 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model="solar-1-mini-chat", - credentials=credentials - ) + model_instance.validate_credentials(model="solar-1-mini-chat", credentials=credentials) except CredentialsValidateFailedError as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e except Exception as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e - diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index af6ec3937c8e25..ecb22e21bd1006 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -49,12 +49,17 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -74,8 +79,16 @@ def _invoke(self, model: str, credentials: dict, # invoke Gemini model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate_anthropic( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke Anthropic large language model @@ -92,7 +105,7 @@ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: li service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] - token = '' + token = "" # get access token from service account credential if service_account_info: @@ -102,40 +115,32 @@ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: li token = credentials.token # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region - if 'opus' or 'claude-3-5-sonnet' in model: - location = 'us-east5' + if "opus" or "claude-3-5-sonnet" in model: + location = "us-east5" else: - location = 'us-central1' - + location = "us-central1" + # use access token to authenticate if token: - client = AnthropicVertex( - region=location, - project_id=project_id, - access_token=token - ) + client = AnthropicVertex(region=location, project_id=project_id, access_token=token) # When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable else: client = AnthropicVertex( - region=location, + region=location, project_id=project_id, ) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system response = client.messages.create( - model=model, - messages=prompt_message_dicts, - stream=stream, - **model_parameters, - **extra_model_kwargs + model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -143,8 +148,9 @@ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: li return self._handle_claude_response(model, credentials, response, prompt_messages) - def _handle_claude_response(self, model: str, credentials: dict, response: Message, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_claude_response( + self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -156,9 +162,7 @@ def _handle_claude_response(self, model: str, credentials: dict, response: Messa """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.content[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text) # calculate num tokens if response.usage: @@ -175,16 +179,18 @@ def _handle_claude_response(self, model: str, credentials: dict, response: Messa # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_claude_stream_response( + self, + model: str, + credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -196,7 +202,7 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -217,18 +223,16 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='' - ), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text if chunk_text else "", ) index = chunk.index yield LLMResultChunk( @@ -237,12 +241,14 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) except Exception as ex: raise InvokeError(str(ex)) - def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_claude_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -262,10 +268,7 @@ def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_toke # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -281,7 +284,7 @@ def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_toke total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -295,13 +298,13 @@ def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() if first_loop: - system=message.content - first_loop=False + system = message.content + first_loop = False else: - system+="\n" - system+=message.content + system += "\n" + system += message.content prompt_message_dicts = [] for message in prompt_messages: @@ -323,10 +326,7 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -336,7 +336,7 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -345,16 +345,14 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict base64_data = data_split[1] if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) @@ -370,8 +368,13 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -384,7 +387,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -394,13 +397,10 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -416,14 +416,16 @@ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -435,20 +437,25 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -462,7 +469,7 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop @@ -494,26 +501,21 @@ def _generate(self, model: str, credentials: dict, else: history.append(content) - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } - google_model = glm.GenerativeModel( - model_name=model, - system_instruction=system_instruction - ) + google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, - generation_config=glm.GenerationConfig( - **config_kwargs - ), + generation_config=glm.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, - tools=self._convert_tools_to_glm_tool(tools) if tools else None + tools=self._convert_tools_to_glm_tool(tools) if tools else None, ) if stream: @@ -521,8 +523,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -533,9 +536,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: glm :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.candidates[0].content.parts[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -554,8 +555,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: glm return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -568,9 +570,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index = -1 for chunk in response: for part in chunk.candidates[0].content.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -579,35 +579,31 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - - if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + + if not hasattr(chunk, "finish_reason") or not chunk.finish_reason: # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -615,8 +611,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=index, message=assistant_prompt_message, finish_reason=chunk.candidates[0].finish_reason, - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -631,9 +627,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -658,7 +652,7 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: if isinstance(message, UserPromptMessage): glm_content = glm.Content(role="user", parts=[]) - if (isinstance(message.content, str)): + if isinstance(message.content, str): glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) else: parts = [] @@ -666,8 +660,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: if c.type == PromptMessageContentType.TEXT: parts.append(glm.Part.from_text(c.data)) else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) glm_content = glm.Content(role="user", parts=parts) return glm_content @@ -675,22 +669,33 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: if message.content: glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) if message.tool_calls: - glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))]) + glm_content = glm.Content( + role="model", + parts=[ + glm.Part.from_function_response( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ], + ) return glm_content elif isinstance(message, ToolPromptMessage): - glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))]) + glm_content = glm.Content( + role="function", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], + ) return glm_content else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -702,25 +707,20 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke emd = gml.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -736,5 +736,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 2404ba589431a7..519373a7f31a35 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -29,9 +29,9 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): Model class for Vertex AI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,23 +51,12 @@ def _invoke(self, model: str, credentials: dict, client = VertexTextEmbeddingModel.from_pretrained(model) - embeddings_batch, embedding_used_tokens = self._embedding_invoke( - client=client, - texts=texts - ) + embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts) # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=embedding_used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens) - return TextEmbeddingResult( - embeddings=embeddings_batch, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -115,15 +104,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = VertexTextEmbeddingModel.from_pretrained(model) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'] - ) + self._embedding_invoke(model=model, client=client, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore """ Invoke embedding model @@ -154,10 +139,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -168,14 +150,14 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -183,15 +165,15 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py index 3cbfb088d12536..466a86fd36a181 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-1.0-pro-002` model for validate, - model_instance.validate_credentials( - model='gemini-1.0-pro-002', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-1.0-pro-002", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index a4d89dabcbc076..d6f1356651e7b3 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -69,31 +69,26 @@ def is_compatible_with_legacy(credentials: dict) -> bool: def from_credentials(cls, credentials): """Initialize the client using the credentials provided.""" args = { - "base_url": credentials['api_endpoint_host'], - "region": credentials['volc_region'], + "base_url": credentials["api_endpoint_host"], + "region": credentials["volc_region"], } if credentials.get("auth_method") == "api_key": args = { **args, - "api_key": credentials['volc_api_key'], + "api_key": credentials["volc_api_key"], } else: args = { **args, - "ak": credentials['volc_access_key_id'], - "sk": credentials['volc_secret_access_key'], + "ak": credentials["volc_access_key_id"], + "sk": credentials["volc_secret_access_key"], } if cls.is_compatible_with_legacy(credentials): - args = { - **args, - "base_url": DEFAULT_V3_ENDPOINT - } + args = {**args, "base_url": DEFAULT_V3_ENDPOINT} - client = ArkClientV3( - **args - ) - client.endpoint_id = credentials['endpoint_id'] + client = ArkClientV3(**args) + client.endpoint_id = credentials["endpoint_id"] return client @staticmethod @@ -107,54 +102,48 @@ def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - content.append(ChatCompletionContentPartTextParam( - text=message_content.text, - type='text', - )) + content.append( + ChatCompletionContentPartTextParam( + text=message_content.text, + type="text", + ) + ) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append(ChatCompletionContentPartImageParam( - image_url=ImageURL( - url=image_data, - detail=message_content.detail.value, - ), - type='image_url', - )) - message_dict = ChatCompletionUserMessageParam( - role='user', - content=content - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type="image_url", + ) + ) + message_dict = ChatCompletionUserMessageParam(role="user", content=content) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = ChatCompletionAssistantMessageParam( content=message.content, - role='assistant', - tool_calls=None if not message.tool_calls else [ + role="assistant", + tool_calls=None + if not message.tool_calls + else [ ChatCompletionMessageToolCallParam( id=call.id, - function=Function( - name=call.function.name, - arguments=call.function.arguments - ), - type='function' - ) for call in message.tool_calls - ] + function=Function(name=call.function.name, arguments=call.function.arguments), + type="function", + ) + for call in message.tool_calls + ], ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = ChatCompletionSystemMessageParam( - content=message.content, - role='system' - ) + message_dict = ChatCompletionSystemMessageParam(content=message.content, role="system") elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = ChatCompletionToolMessageParam( - content=message.content, - role='tool', - tool_call_id=message.tool_call_id + content=message.content, role="tool", tool_call_id=message.tool_call_id ) else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -164,23 +153,25 @@ def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam @staticmethod def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: return ChatCompletionToolParam( - type='function', + type="function", function=FunctionDefinition( name=message.name, description=message.description, parameters=message.parameters, - ) + ), ) - def chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> ChatCompletion: + def chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: """Block chat""" return self.ark.chat.completions.create( model=self.endpoint_id, @@ -194,15 +185,17 @@ def chat(self, messages: list[PromptMessage], temperature=temperature, ) - def stream_chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> Generator[ChatCompletionChunk]: + def stream_chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: """Stream chat""" chunks = self.ark.chat.completions.create( stream=True, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 1978c11680f597..025b1ed6d2bae1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -25,12 +25,12 @@ def set_endpoint_id(self, endpoint_id: str): self.endpoint_id = endpoint_id @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] + def from_credential(cls, credentials: dict) -> "MaaSClient": + host = credentials["api_endpoint_host"] + region = credentials["volc_region"] + ak = credentials["volc_access_key_id"] + sk = credentials["volc_secret_access_key"] + endpoint_id = credentials["endpoint_id"] client = cls(host, region) client.set_endpoint_id(endpoint_id) @@ -40,8 +40,8 @@ def from_credential(cls, credentials: dict) -> 'MaaSClient': def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + "parameters": params, + "messages": [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], **extra_model_kwargs, } if not stream: @@ -55,9 +55,7 @@ def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extr ) def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } + req = {"input": texts} return super().embeddings(self.endpoint_id, req) @staticmethod @@ -65,49 +63,40 @@ def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + message_dict = {"role": ChatRole.USER, "content": message.content} else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + raise ValueError("Content object type only support image_url") elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + { + "type": "image_url", + "image_url": { + "url": "", + "image_bytes": image_data, + "detail": message_content.detail, + }, } - }) + ) - message_dict = {'role': ChatRole.USER, 'content': content} + message_dict = {"role": ChatRole.USER, "content": content} elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} + message_dict = {"role": ChatRole.ASSISTANT, "content": message.content} if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict["tool_calls"] = [ + {"name": call.function.name, "arguments": call.function.arguments} for call in message.tool_calls ] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = {"role": ChatRole.SYSTEM, "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = {"role": ChatRole.FUNCTION, "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -130,5 +119,5 @@ def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): "name": tool.name, "description": tool.description, "parameters": tool.parameters, - } + }, } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 21ffaf1258fefd..8b9c3462652027 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -102,43 +102,43 @@ class ServiceNotOpen(MaasException): AuthErrors = { - 'SignatureDoesNotMatch': SignatureDoesNotMatch, - 'MissingAuthenticationHeader': MissingAuthenticationHeader, - 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, - 'AuthenticationExpire': AuthenticationExpire, - 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatch, + "MissingAuthenticationHeader": MissingAuthenticationHeader, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, + "AuthenticationExpire": AuthenticationExpire, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, } BadRequestErrors = { - 'MissingParameter': MissingParameter, - 'InvalidParameter': InvalidParameter, - 'EndpointIsInvalid': EndpointIsInvalid, - 'EndpointIsNotEnable': EndpointIsNotEnable, - 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, - 'ReqTextExistRisk': ReqTextExistRisk, - 'RespTextExistRisk': RespTextExistRisk, - 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, - 'ServiceNotOpen': ServiceNotOpen, + "MissingParameter": MissingParameter, + "InvalidParameter": InvalidParameter, + "EndpointIsInvalid": EndpointIsInvalid, + "EndpointIsNotEnable": EndpointIsNotEnable, + "ModelNotSupportStreamMode": ModelNotSupportStreamMode, + "ReqTextExistRisk": ReqTextExistRisk, + "RespTextExistRisk": RespTextExistRisk, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, + "ServiceNotOpen": ServiceNotOpen, } RateLimitErrors = { - 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, - 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, - 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, + "EndpointRateLimitExceeded": EndpointRateLimitExceeded, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, } ServerUnavailableErrors = { - 'InternalServiceError': InternalServiceError, - 'EndpointIsPending': EndpointIsPending, - 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, + "InternalServiceError": InternalServiceError, + "EndpointIsPending": EndpointIsPending, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, } ConnectionErrors = { - 'ClientSDKRequestError': ClientSDKRequestError, - 'RequestTimeout': RequestTimeout, - 'ServiceConnectionTimeout': ServiceConnectionTimeout, - 'ServiceConnectionRefused': ServiceConnectionRefused, - 'ServiceConnectionClosed': ServiceConnectionClosed, + "ClientSDKRequestError": ClientSDKRequestError, + "RequestTimeout": RequestTimeout, + "ServiceConnectionTimeout": ServiceConnectionTimeout, + "ServiceConnectionRefused": ServiceConnectionRefused, + "ServiceConnectionClosed": ServiceConnectionClosed, } ErrorCodeMap = { diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 64f342f16e936b..53f320736bcb68 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole from .maas import MaasException, MaasService -__all__ = ['MaasService', 'ChatRole', 'MaasException'] +__all__ = ["MaasService", "ChatRole", "MaasException"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 053432a089ee46..8f8139426c5fee 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -8,12 +8,12 @@ class MetaData: def __init__(self): - self.algorithm = '' - self.credential_scope = '' - self.signed_headers = '' - self.date = '' - self.region = '' - self.service = '' + self.algorithm = "" + self.credential_scope = "" + self.signed_headers = "" + self.date = "" + self.region = "" + self.service = "" def set_date(self, date): self.date = date @@ -36,23 +36,23 @@ def set_signed_headers(self, signed_headers): class SignResult: def __init__(self): - self.xdate = '' - self.xCredential = '' - self.xAlgorithm = '' - self.xSignedHeaders = '' - self.xSignedQueries = '' - self.xSignature = '' - self.xContextSha256 = '' - self.xSecurityToken = '' + self.xdate = "" + self.xCredential = "" + self.xAlgorithm = "" + self.xSignedHeaders = "" + self.xSignedQueries = "" + self.xSignature = "" + self.xContextSha256 = "" + self.xSecurityToken = "" - self.authorization = '' + self.authorization = "" def __str__(self): - return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()]) class Credentials: - def __init__(self, ak, sk, service, region, session_token=''): + def __init__(self, ak, sk, service, region, session_token=""): self.ak = ak self.sk = sk self.service = service @@ -72,73 +72,88 @@ def set_session_token(self, session_token): class Signer: @staticmethod def sign(request, credentials): - if request.path == '': - request.path = '/' - if request.method != 'GET' and not ('Content-Type' in request.headers): - request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + if request.path == "": + request.path = "/" + if request.method != "GET" and not ("Content-Type" in request.headers): + request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" format_date = Signer.get_current_format_date() - request.headers['X-Date'] = format_date - if credentials.session_token != '': - request.headers['X-Security-Token'] = credentials.session_token + request.headers["X-Date"] = format_date + if credentials.session_token != "": + request.headers["X-Security-Token"] = credentials.session_token md = MetaData() - md.set_algorithm('HMAC-SHA256') + md.set_algorithm("HMAC-SHA256") md.set_service(credentials.service) md.set_region(credentials.region) md.set_date(format_date[:8]) hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) - md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + md.set_credential_scope("/".join([md.date, md.region, md.service, "request"])) - signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_str = "\n".join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) - request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) + request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) return @staticmethod def hashed_canonical_request_v4(request, meta): body_hash = Util.sha256(request.body) - request.headers['X-Content-Sha256'] = body_hash + request.headers["X-Content-Sha256"] = body_hash signed_headers = {} for key in request.headers: - if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] - if 'host' in signed_headers: - v = signed_headers['host'] - if v.find(':') != -1: - split = v.split(':') + if "host" in signed_headers: + v = signed_headers["host"] + if v.find(":") != -1: + split = v.split(":") port = split[1] - if str(port) == '80' or str(port) == '443': - signed_headers['host'] = split[0] + if str(port) == "80" or str(port) == "443": + signed_headers["host"] = split[0] - signed_str = '' + signed_str = "" for key in sorted(signed_headers.keys()): - signed_str += key + ':' + signed_headers[key] + '\n' + signed_str += key + ":" + signed_headers[key] + "\n" - meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + meta.set_signed_headers(";".join(sorted(signed_headers.keys()))) - canonical_request = '\n'.join( - [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, - meta.signed_headers, body_hash]) + canonical_request = "\n".join( + [ + request.method, + Util.norm_uri(request.path), + Util.norm_query(request.query), + signed_str, + meta.signed_headers, + body_hash, + ] + ) return Util.sha256(canonical_request) @staticmethod def get_signing_secret_key_v4(sk, date, region, service): - date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + date = Util.hmac_sha256(bytes(sk, encoding="utf-8"), date) region = Util.hmac_sha256(date, region) service = Util.hmac_sha256(region, service) - return Util.hmac_sha256(service, 'request') + return Util.hmac_sha256(service, "request") @staticmethod def build_auth_header_v4(signature, meta, credentials): - credential = credentials.ak + '/' + meta.credential_scope - return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + credential = credentials.ak + "/" + meta.credential_scope + return ( + meta.algorithm + + " Credential=" + + credential + + ", SignedHeaders=" + + meta.signed_headers + + ", Signature=" + + signature + ) @staticmethod def get_current_format_date(): - return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") + return datetime.datetime.now(tz=pytz.timezone("UTC")).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py index 7271ae63fd7309..096339b3c74f83 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -6,7 +6,7 @@ from .auth import Signer -VERSION = 'v1.0.137' +VERSION = "v1.0.137" class Service: @@ -40,8 +40,9 @@ def get(self, api, params, doseq=0): Signer.sign(r, self.service_info.credentials) url = r.build(doseq) - resp = self.session.get(url, headers=r.headers, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.get( + url, headers=r.headers, timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout) + ) if resp.status_code == 200: return resp.text else: @@ -52,15 +53,19 @@ def post(self, api, params, form): raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.headers["Content-Type"] = "application/x-www-form-urlencoded" r.form = self.merge(api_info.form, form) r.body = urlencode(r.form, True) Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.form, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return resp.text else: @@ -71,21 +76,25 @@ def json(self, api, params, body): raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' + r.headers["Content-Type"] = "application/json" r.body = body Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return json.dumps(resp.json()) else: raise Exception(resp.text.encode("utf-8")) def put(self, url, file_path, headers): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: resp = self.session.put(url, headers=headers, data=f) if resp.status_code == 200: return True, resp.text.encode("utf-8") @@ -105,7 +114,7 @@ def prepare_request(self, api_info, params, doseq=0): params[key] = str(params[key]) elif type(params[key]) == list: if not doseq: - params[key] = ','.join(params[key]) + params[key] = ",".join(params[key]) connection_timeout = self.service_info.connection_timeout socket_timeout = self.service_info.socket_timeout @@ -117,8 +126,8 @@ def prepare_request(self, api_info, params, doseq=0): r.set_socket_timeout(socket_timeout) headers = self.merge(api_info.header, self.service_info.header) - headers['Host'] = self.service_info.host - headers['User-Agent'] = 'volc-sdk-python/' + VERSION + headers["Host"] = self.service_info.host + headers["User-Agent"] = "volc-sdk-python/" + VERSION r.set_headers(headers) query = self.merge(api_info.query, params) @@ -143,13 +152,13 @@ def merge(param1, param2): class Request: def __init__(self): - self.schema = '' - self.method = '' - self.host = '' - self.path = '' + self.schema = "" + self.method = "" + self.host = "" + self.path = "" self.headers = OrderedDict() self.query = OrderedDict() - self.body = '' + self.body = "" self.form = {} self.connection_timeout = 0 self.socket_timeout = 0 @@ -182,11 +191,11 @@ def set_socket_timeout(self, socket_timeout): self.socket_timeout = socket_timeout def build(self, doseq=0): - return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) + return self.schema + "://" + self.host + self.path + "?" + urlencode(self.query, doseq) class ServiceInfo: - def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme="http"): self.host = host self.header = header self.credentials = credentials @@ -204,4 +213,4 @@ def __init__(self, method, path, query, form, header): self.header = header def __str__(self): - return 'method: ' + self.method + ', path: ' + self.path + return "method: " + self.method + ", path: " + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py index 7eb5fdfa9122a0..44f99599653c4b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py @@ -7,28 +7,28 @@ class Util: @staticmethod def norm_uri(path): - return quote(path).replace('%2F', '/').replace('+', '%20') + return quote(path).replace("%2F", "/").replace("+", "%20") @staticmethod def norm_query(params): - query = '' + query = "" for key in sorted(params.keys()): if type(params[key]) == list: for k in params[key]: - query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&" else: - query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(params[key], safe="-_.~") + "&" query = query[:-1] - return query.replace('+', '%20') + return query.replace("+", "%20") @staticmethod def hmac_sha256(key, content): - return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + return hmac.new(key, bytes(content, encoding="utf-8"), hashlib.sha256).digest() @staticmethod def sha256(content): if isinstance(content, str) is True: - return hashlib.sha256(content.encode('utf-8')).hexdigest() + return hashlib.sha256(content.encode("utf-8")).hexdigest() else: return hashlib.sha256(content).hexdigest() @@ -36,8 +36,8 @@ def sha256(content): def to_hex(content): lst = [] for ch in content: - hv = hex(ch).replace('0x', '') + hv = hex(ch).replace("0x", "") if len(hv) == 1: - hv = '0' + hv + hv = "0" + hv lst.append(hv) return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py index 8b14d026d96795..3825fd65741ef5 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py @@ -43,9 +43,7 @@ def json_to_object(json_str, req_id=None): def gen_req_id(): - return datetime.now().strftime("%Y%m%d%H%M%S") + format( - random.randint(0, 2 ** 64 - 1), "020X" - ) + return datetime.now().strftime("%Y%m%d%H%M%S") + format(random.randint(0, 2**64 - 1), "020X") class SSEDecoder: @@ -53,13 +51,13 @@ def __init__(self, source): self.source = source def _read(self): - data = b'' + data = b"" for chunk in self.source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data @@ -67,13 +65,13 @@ def next(self): for chunk in self._read(): for line in chunk.splitlines(): # skip comment - if line.startswith(b':'): + if line.startswith(b":"): continue - if b':' in line: - field, value = line.split(b':', 1) + if b":" in line: + field, value = line.split(b":", 1) else: - field, value = line, b'' + field, value = line, b"" - if field == b'data' and len(value) > 0: + if field == b"data" and len(value) > 0: yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 3cbe9d9f099e83..01f15aec249e04 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -9,9 +9,7 @@ class MaasService(Service): def __init__(self, host, region, connection_timeout=60, socket_timeout=60): - service_info = self.get_service_info( - host, region, connection_timeout, socket_timeout - ) + service_info = self.get_service_info(host, region, connection_timeout, socket_timeout) self._apikey = None api_info = self.get_api_info() super().__init__(service_info, api_info) @@ -35,9 +33,7 @@ def get_service_info(host, region, connection_timeout, socket_timeout): def get_api_info(): api_info = { "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), - "embeddings": ApiInfo( - "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} - ), + "embeddings": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}), } return api_info @@ -52,9 +48,7 @@ def stream_chat(self, endpoint_id, req): try: req["stream"] = True - res = self._call( - endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True - ) + res = self._call(endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True) decoder = SSEDecoder(res) @@ -64,8 +58,7 @@ def iter_fn(): return try: - res = json_to_object( - str(data, encoding="utf-8"), req_id=req_id) + res = json_to_object(str(data, encoding="utf-8"), req_id=req_id) except Exception: raise @@ -95,8 +88,7 @@ def _request(self, endpoint_id, api, req, params={}): apikey = self._apikey try: - res = self._call(endpoint_id, api, req_id, params, - json.dumps(req).encode("utf-8"), apikey) + res = self._call(endpoint_id, api, req_id, params, json.dumps(req).encode("utf-8"), apikey) resp = dict_to_object(res.json()) if resp and isinstance(resp, dict): resp["req_id"] = req_id @@ -109,9 +101,9 @@ def _request(self, endpoint_id, api, req, params={}): def _validate(self, api, req_id): credentials_exist = ( - self.service_info.credentials is not None and - self.service_info.credentials.sk is not None and - self.service_info.credentials.ak is not None + self.service_info.credentials is not None + and self.service_info.credentials.sk is not None + and self.service_info.credentials.ak is not None ) if not self._apikey and not credentials_exist: @@ -150,15 +142,12 @@ def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=Fals raw = res.text.encode() res.close() try: - resp = json_to_object( - str(raw, encoding="utf-8"), req_id=req_id) + resp = json_to_object(str(raw, encoding="utf-8"), req_id=req_id) except Exception: raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, req_id - ) + raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) @@ -173,11 +162,13 @@ def __init__(self, code_n, code, message, req_id): self.req_id = req_id def __str__(self): - return ("Detailed exception information is listed below.\n" + - "req_id: {}\n" + - "code_n: {}\n" + - "code: {}\n" + - "message: {}").format(self.req_id, self.code_n, self.code, self.message) + return ( + "Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}" + ).format(self.req_id, self.code_n, self.code, self.message) def new_client_sdk_request_error(raw, req_id=""): @@ -189,25 +180,19 @@ def __init__(self, response, request_id) -> None: self.response = response self.request_id = request_id - def stream_to_file( - self, - file: str - ) -> None: + def stream_to_file(self, file: str) -> None: is_first = True - error_bytes = b'' + error_bytes = b"" with open(file, mode="wb") as f: for data in self.response: - if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + if len(error_bytes) > 0 or (is_first and '"error":' in str(data)): error_bytes += data else: f.write(data) if len(error_bytes) > 0: - resp = json_to_object( - str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, self.request_id - ) + resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 996c66e604fea2..98409ab872ee03 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -49,10 +49,17 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: if ArkClientV3.is_legacy(credentials): return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -71,12 +78,12 @@ def _validate_credentials_v2(credentials: dict) -> None: try: client.chat( { - 'max_new_tokens': 16, - 'temperature': 0.7, - 'top_p': 0.9, - 'top_k': 15, + "max_new_tokens": 16, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 15, }, - [UserPromptMessage(content='ping\nAnswer: ')], + [UserPromptMessage(content="ping\nAnswer: ")], ) except MaasException as e: raise CredentialsValidateFailedError(e.message) @@ -85,13 +92,22 @@ def _validate_credentials_v2(credentials: dict) -> None: def _validate_credentials_v3(credentials: dict) -> None: client = ArkClientV3.from_credentials(credentials) try: - client.chat(max_tokens=16, temperature=0.7, top_p=0.9, - messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + client.chat( + max_tokens=16, + temperature=0.7, + top_p=0.9, + messages=[UserPromptMessage(content="ping\nAnswer: ")], + ) except Exception as e: raise CredentialsValidateFailedError(e) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if ArkClientV3.is_legacy(credentials): return self._get_num_tokens_v2(prompt_messages) return self._get_num_tokens_v3(prompt_messages) @@ -100,8 +116,7 @@ def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int: if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + messages_dict = [MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -113,8 +128,7 @@ def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - ArkClientV3.convert_prompt_message(m) for m in messages] + messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -122,97 +136,108 @@ def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: return num_tokens - def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v2( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [ - MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools - ] - resp = MaaSClient.wrap_exception( - lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) + extra_model_kwargs["tools"] = [MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools] + resp = MaaSClient.wrap_exception(lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) def _handle_stream_chat_response() -> Generator: for index, r in enumerate(resp): - choices = r['choices'] + choices = r["choices"] if not choices: continue choice = choices[0] - message = choice['message'] + message = choice["message"] usage = None - if r.get('usage'): - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=r['usage']['prompt_tokens'], - completion_tokens=r['usage']['completion_tokens'] - ) + if r.get("usage"): + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=r["usage"]["prompt_tokens"], + completion_tokens=r["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] + content=message["content"] if message["content"] else "", tool_calls=[] ), usage=usage, - finish_reason=choice.get('finish_reason'), + finish_reason=choice.get("finish_reason"), ), ) def _handle_chat_response() -> LLMResult: - choices = resp['choices'] + choices = resp["choices"] if not choices: raise ValueError("No choices found") choice = choices[0] - message = choice['message'] + message = choice["message"] # parse tool calls tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: + if message["tool_calls"]: + for call in message["tool_calls"]: tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], + id=call["function"]["name"], + type=call["type"], function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + name=call["function"]["name"], arguments=call["function"]["arguments"] + ), ) tool_calls.append(tool_call) - usage = resp['usage'] + usage = resp["usage"] return LLMResult( model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', + content=message["content"] if message["content"] else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + ), ) if not stream: return _handle_chat_response() return _handle_stream_chat_response() - def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v3( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = ArkClientV3.from_credentials(credentials) req_params = get_v3_req_params(credentials, model_parameters, stop) if tools: - req_params['tools'] = tools + req_params["tools"] = tools def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: for chunk in chunks: @@ -225,14 +250,15 @@ def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Gene prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=choice.index, - message=AssistantPromptMessage( - content=choice.delta.content, - tool_calls=[] - ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens - ) if chunk.usage else None, + message=AssistantPromptMessage(content=choice.delta.content, tool_calls=[]), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.usage + else None, finish_reason=choice.finish_reason, ), ) @@ -248,9 +274,8 @@ def _handle_chat_response(resp: ChatCompletion) -> LLMResult: id=call.id, type=call.type, function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call.function.name, - arguments=call.function.arguments - ) + name=call.function.name, arguments=call.function.arguments + ), ) tool_calls.append(tool_call) @@ -262,10 +287,12 @@ def _handle_chat_response(resp: ChatCompletion) -> LLMResult: content=message.content if message.content else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), ) if not stream: @@ -277,72 +304,56 @@ def _handle_chat_response(resp: ChatCompletion) -> LLMResult: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ model_config = get_model_config(credentials) rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', - type=ParameterType.INT, - min=1, - default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + name="top_k", type=ParameterType.INT, min=1, default=1, label=I18nObject(zh_Hans="Top K", en_US="Top K") ), ParameterRule( - name='presence_penalty', + name="presence_penalty", type=ParameterType.FLOAT, - use_template='presence_penalty', + use_template="presence_penalty", label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", type=ParameterType.FLOAT, - use_template='frequency_penalty', + use_template="frequency_penalty", label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=model_config.properties.max_tokens, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ] @@ -352,9 +363,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index a882f68a36cc97..d8be14b0247698 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -16,138 +16,127 @@ class ModelConfig(BaseModel): configs: dict[str, ModelConfig] = { - 'Doubao-pro-4k': ModelConfig( + "Doubao-pro-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-4k': ModelConfig( + "Doubao-lite-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-32k': ModelConfig( + "Doubao-pro-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-32k': ModelConfig( + "Doubao-lite-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-128k': ModelConfig( + "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Doubao-lite-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Skylark2-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Llama3-8B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-8B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Llama3-70B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-70B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Moonshot-v1-8k': ModelConfig( + "Moonshot-v1-8k": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-32k': ModelConfig( + "Moonshot-v1-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-128k': ModelConfig( + "Moonshot-v1-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B': ModelConfig( + "GLM3-130B": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B-Fin': ModelConfig( + "GLM3-130B-Fin": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], + ), + "Mistral-7B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[] ), - 'Mistral-7B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), - features=[] - ) } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = configs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_tokens=int(credentials.get('max_tokens', 0)), - mode=LLMMode.value_of(credentials.get('mode', 'chat')), + context_size=int(credentials.get("context_size", 0)), + max_tokens=int(credentials.get("max_tokens", 0)), + mode=LLMMode.value_of(credentials.get("mode", "chat")), ), - features=[] + features=[], ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_prompt_tokens'] = model_configs.properties.context_size - req_params['max_new_tokens'] = model_configs.properties.max_tokens + req_params["max_prompt_tokens"] = model_configs.properties.context_size + req_params["max_new_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_new_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("top_k"): + req_params["top_k"] = model_parameters.get("top_k") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params -def get_v3_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_tokens'] = model_configs.properties.max_tokens + req_params["max_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 74cf26247cb233..ce4f0c3ab1960e 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -11,20 +11,18 @@ class ModelConfig(BaseModel): ModelConfigs = { - 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=32) - ), + "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_chunks=int(credentials.get('max_chunks', 0)), + context_size=int(credentials.get("context_size", 0)), + max_chunks=int(credentials.get("max_chunks", 0)), ) ) return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index d54aeeb0b15f89..3cdcd2740c9b8c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -40,9 +40,9 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): Model class for VolcengineMaaS text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -57,37 +57,27 @@ def _invoke(self, model: str, credentials: dict, return self._generate_v3(model, credentials, texts, user) - def _generate_v2(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v2( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"]) - result = TextEmbeddingResult( - model=model, - embeddings=[v['embedding'] for v in resp['data']], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage) return result - def _generate_v3(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v3( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = ArkClientV3.from_credentials(credentials) resp = client.embeddings(texts) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp.usage.total_tokens) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=[v.embedding for v in resp.data], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage) return result @@ -120,13 +110,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except MaasException as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: raise CredentialsValidateFailedError(e) @@ -150,12 +140,12 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ model_config = get_model_config(credentials) model_properties = { ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, - ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks, } entity = AIModelEntity( model=model, @@ -165,10 +155,10 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_properties=model_properties, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -184,10 +174,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -198,7 +185,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index 017856bdde2e16..d72d1bd83a7e25 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -11,7 +11,7 @@ RateLimitReachedError, ) -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} baidu_access_tokens_lock = Lock() @@ -22,49 +22,46 @@ class BaiduAccessToken: def __init__(self, api_key: str) -> None: self.api_key = api_key - self.access_token = '' + self.access_token = "" self.expires = datetime.now() + timedelta(days=3) @staticmethod def _get_access_token(api_key: str, secret_key: str) -> str: """ - request access token from Baidu + request access token from Baidu """ try: response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", + headers={"Content-Type": "application/json", "Accept": "application/json"}, ) except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + raise InvalidAuthenticationError(f"Failed to get access token from Baidu: {e}") resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': + if "error" in resp: + if resp["error"] == "invalid_client": raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': + elif resp["error"] == "unknown_error": raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': + elif resp["error"] == "invalid_request": raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': + elif resp["error"] == "rate_limit_exceeded": raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') else: raise Exception(f'Unknown error: {resp["error_description"]}') - return resp['access_token'] + return resp["access_token"] @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. """ # loop up cache, remove expired access token @@ -98,49 +95,49 @@ def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': class _CommonWenxin: api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', - 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', - 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', - 'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-3.5-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-3.5-8k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205", + "ernie-3.5-8k-1222": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222", + "ernie-3.5-4k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-3.5-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k", + "ernie-4.0-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-4.0-8k-latest": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-speed-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", + "ernie-speed-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k", + "ernie-speed-appbuilder": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas", + "ernie-lite-8k-0922": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-lite-8k-0308": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k", + "ernie-character-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", + "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", + "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", + "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", + "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", } function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' + "ernie-bot", + "ernie-bot-8k", + "ernie-3.5-8k", + "ernie-3.5-8k-0205", + "ernie-3.5-8k-1222", + "ernie-3.5-4k-0205", + "ernie-3.5-128k", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview", + "yi_34b_chat", ] - api_key: str = '' - secret_key: str = '' + api_key: str = "" + secret_key: str = "" def __init__(self, api_key: str, secret_key: str): self.api_key = api_key @@ -148,10 +145,7 @@ def __init__(self, api_key: str, secret_key: str): @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - credentials_kwargs = { - "api_key": credentials['api_key'], - "secret_key": credentials['secret_key'] - } + credentials_kwargs = {"api_key": credentials["api_key"], "secret_key": credentials["secret_key"]} return credentials_kwargs def _handle_error(self, code: int, msg: str): @@ -187,13 +181,13 @@ def _handle_error(self, code: int, msg: str): 336105: BadRequestError, 336200: InternalServerError, 336303: BadRequestError, - 337006: BadRequestError + 337006: BadRequestError, } if code in error_map: raise error_map[code](msg) else: - raise InternalServerError(f'Unknown error: {msg}') + raise InternalServerError(f"Unknown error: {msg}") def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 8109949b1d80b7..07b970f8104c8f 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -15,33 +15,39 @@ class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - def __init__(self, content: str, role: str = 'user') -> None: + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role -class ErnieBotModel(_CommonWenxin): - - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: +class ErnieBotModel(_CommonWenxin): + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -49,22 +55,23 @@ def generate(self, model: str, stream: bool, messages: list[ErnieMessage], access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, messages=messages_cloned, stream=stream, parameters=parameters, tools=tools, stop=stop, user=user + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) @@ -73,10 +80,11 @@ def generate(self, model: str, stream: bool, messages: list[ErnieMessage], def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str] + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') + raise BadRequestError(f"Invalid model: {model}") # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') @@ -85,86 +93,106 @@ def _check_parameters(self, model: str, parameters: dict[str, Any], # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') - - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + raise BadRequestError("stop item should not exceed 20 characters.") + + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role == "function": + raise BadRequestError("The first message should be user message.") """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') + raise BadRequestError("The number of messages should not be zero.") # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } - if 'max_tokens' in parameters and type(parameters['max_tokens']) == int: - body['max_output_tokens'] = parameters['max_tokens'] + if "max_tokens" in parameters and type(parameters["max_tokens"]) == int: + body["max_output_tokens"] = parameters["max_tokens"] - if 'presence_penalty' in parameters and type(parameters['presence_penalty']) == float: - body['penalty_score'] = parameters['presence_penalty'] + if "presence_penalty" in parameters and type(parameters["presence_penalty"]) == float: + body["penalty_score"] = parameters["presence_penalty"] if system_message: - body['system'] = system_message + body["system"] = system_message return body def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message @@ -173,19 +201,19 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -195,23 +223,23 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - result = data['result'] - is_end = data['is_end'] + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 140606298cac37..1ff0ac7ad27ce9 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -30,42 +30,82 @@ You should also complete the text started with ``` but not tell ``` directly. """ + class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ @@ -74,34 +114,44 @@ def _transform_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -113,10 +163,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -126,36 +176,53 @@ def tokens(text: str): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) @@ -180,41 +247,47 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: ErnieMessage + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -225,10 +298,7 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 10ac1a1861e29f..db323ae4c1c0c0 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -29,38 +29,38 @@ def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): access_token = self._get_access_token() - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url, data=dumps(body), headers=headers) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") return self._handle_embed_response(model, resp) def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: if len(texts) == 0: - raise BadRequestError('The number of texts should not be zero.') + raise BadRequestError("The number of texts should not be zero.") body = { - 'input': texts, - 'user_id': user, + "input": texts, + "user_id": user, } return body def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - embeddings = [v['embedding'] for v in data['data']] - _usage = data['usage'] - tokens = _usage['prompt_tokens'] - total_tokens = _usage['total_tokens'] + embeddings = [v["embedding"] for v in data["data"]] + _usage = data["usage"] + tokens = _usage["prompt_tokens"] + total_tokens = _usage["total_tokens"] return embeddings, tokens, total_tokens @@ -69,22 +69,23 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: return WenxinTextEmbedding(api_key, secret_key) - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ - Invoke text embedding model + Invoke text embedding model - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -94,7 +95,6 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], used_total_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -110,9 +110,8 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], _iter = range(0, len(inputs), max_chunks) for i in _iter: embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( - model, - inputs[i: i + max_chunks], - user) + model, inputs[i : i + max_chunks], user + ) used_tokens += _used_tokens used_total_tokens += _total_used_tokens batched_embeddings += embeddings_batch @@ -142,12 +141,12 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens def validate_credentials(self, model: str, credentials: Mapping) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -164,10 +163,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -178,7 +174,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bcf1bc..895af20bc8541d 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `ernie-bot` model for validate, - model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials - ) + model_instance.validate_credentials(model="ernie-bot", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index 0fbd0f55ec9d90..f2e2248680b8fc 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -18,40 +18,37 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 4760e8f1185e8d..b2c837dee100ee 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -65,88 +65,108 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), - ) + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials - - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + validate credentials + + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: raise ValueError( - f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + f"xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type" + ) if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.support_vision: - credentials['support_vision'] = True + credentials["support_vision"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -162,10 +182,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -217,30 +237,30 @@ def tokens(text: str): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -248,9 +268,9 @@ def tokens(text: str): def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): text += item.content @@ -259,7 +279,7 @@ def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: elif isinstance(item, AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -275,19 +295,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -297,7 +311,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -312,151 +326,144 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY, use_template=DefaultParameterName.PRESENCE_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they ' - 'appear in the text so far, increasing the model\'s likelihood to talk about new topics.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they " + "appear in the text so far, increasing the model's likelihood to talk about new topics.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY, use_template=DefaultParameterName.FREQUENCY_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their ' - 'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the ' - 'same line verbatim.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their " + "existing frequency in the text so far, decreasing the model's likelihood to repeat the " + "same line verbatim.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 - ) + precision=2, + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + raise ValueError(f"xinference model ability {extra_args.model_ability} is not supported") features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] - api_key = credentials.get('api_key') or "abc" + api_key = credentials.get("api_key") or "abc" client = OpenAI( base_url=f'{credentials["server_url"]}/v1', @@ -466,34 +473,29 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM ) xinference_client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'presence_penalty': model_parameters.get('presence_penalty', 0.0), - 'frequency_penalty': model_parameters.get('frequency_penalty', 0.0), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), + "presence_penalty": model_parameters.get("presence_penalty", 0.0), + "frequency_penalty": model_parameters.get("frequency_penalty", 0.0), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - vision = credentials.get('support_vision', False) + generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] + vision = credentials.get("support_vision", False) if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, @@ -501,34 +503,34 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError("xinference tool calls does not support stream mode") + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_completion_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError(f"xinference model handle type {type(xinference_model)} is not supported") - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -539,21 +541,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -563,23 +563,25 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -595,15 +597,15 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + content=assistant_message.content, tool_calls=assistant_prompt_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -615,13 +617,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -629,7 +636,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -646,32 +653,31 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: @@ -687,11 +693,16 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -699,14 +710,9 @@ def _handle_completion_generate_response(self, model: str, credentials: dict, pr assistant_message = resp.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[assistant_prompt_message], tools=[], is_completion_model=True ) @@ -724,13 +730,18 @@ def _handle_completion_generate_response(self, model: str, credentials: dict, pr return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -739,40 +750,33 @@ def _handle_completion_stream_response(self, model: str, credentials: dict, prom delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -807,15 +811,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError + PermissionDeniedError, ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index d809537479f40d..1582fe43b95d93 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -22,10 +22,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,24 +45,16 @@ def _invoke(self, model: str, credentials: dict, :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} - params = { - 'documents': docs, - 'query': query, - 'top_n': top_n, - 'return_documents': True - } + params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} try: handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) response = handle.rerank(**params) @@ -69,27 +67,24 @@ def _invoke(self, model: str, credentials: dict, response = handle.rerank(**params) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] + index = result["index"] + page_content = result["document"] if isinstance(result["document"], str) else result["document"]["text"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check if score_threshold is not None: - if result['relevance_score'] >= score_threshold: + if result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -100,34 +95,35 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a rerank model') + "please check model type, the model you want to invoke is not a rerank model" + ) self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -143,53 +139,38 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): - def rerank( - self, - documents: list[str], - query: str, - top_n: Optional[int] = None, - max_chunks_per_doc: Optional[int] = None, - return_documents: Optional[bool] = None, - **kwargs + self, + documents: list[str], + query: str, + top_n: Optional[int] = None, + max_chunks_per_doc: Optional[int] = None, + return_documents: Optional[bool] = None, + **kwargs, ): url = f"{self._base_url}/v1/rerank" request_body = { @@ -205,8 +186,6 @@ def rerank( response = requests.post(url, json=request_body, headers=self.auth_headers) if response.status_code != 200: - raise InvokeServerUnavailableError( - f"Failed to rerank documents, detail: {response.json()['detail']}" - ) + raise InvokeServerUnavailableError(f"Failed to rerank documents, detail: {response.json()['detail']}") response_data = response.json() return response_data diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 62b77f22e59c87..54c8b51654bfbc 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -21,9 +21,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): Model class for Xinference speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -44,27 +42,28 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulAudioModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + "please check model type, the model you want to invoke is not a audio model" + ) audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self.invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -80,23 +79,11 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _speech2text_invoke( @@ -122,21 +109,17 @@ def _speech2text_invoke( :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. :return: text for given audio file """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) response = handle.transcriptions( - audio=file, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature + audio=file, language=language, prompt=prompt, response_format=response_format, temperature=temperature ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -145,17 +128,15 @@ def _speech2text_invoke( def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 3a8d704c25838c..ac704e7de8361a 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -41,12 +42,12 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) @@ -70,13 +71,11 @@ class EmbeddingData(TypedDict): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage ) return result @@ -105,12 +104,12 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=server_url, model_uid=model_uid, @@ -118,8 +117,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - if server_url.endswith('/'): + credentials["max_tokens"] = extra_args.max_tokens + if server_url.endswith("/"): server_url = server_url[:-1] client = Client( @@ -133,32 +132,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}") except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -172,10 +163,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -186,28 +174,26 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index 8cc99fef7c23c2..60db15130299df 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -19,92 +19,91 @@ class XinferenceText2SpeechModel(TTSModel): - def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'ChatTTS': { - 'all': [ - {'name': 'Alloy', 'value': 'alloy'}, - {'name': 'Echo', 'value': 'echo'}, - {'name': 'Fable', 'value': 'fable'}, - {'name': 'Onyx', 'value': 'onyx'}, - {'name': 'Nova', 'value': 'nova'}, - {'name': 'Shimmer', 'value': 'shimmer'}, + "ChatTTS": { + "all": [ + {"name": "Alloy", "value": "alloy"}, + {"name": "Echo", "value": "echo"}, + {"name": "Fable", "value": "fable"}, + {"name": "Onyx", "value": "onyx"}, + {"name": "Nova", "value": "nova"}, + {"name": "Shimmer", "value": "shimmer"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ try: - if ("/" in credentials['model_uid'] or - "?" in credentials['model_uid'] or - "#" in credentials['model_uid']): + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'text-to-audio' not in extra_param.model_ability: + if "text-to-audio" not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a text-to-audio model') + "please check model type, the model you want to invoke is not a text-to-audio model" + ) if extra_param.model_family and extra_param.model_family in self.model_voices: - credentials['audio_model_name'] = extra_param.model_family + credentials["audio_model_name"] = extra_param.model_family else: - credentials['audio_model_name'] = '__default' + credentials["audio_model_name"] = "__default" self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -120,18 +119,16 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -147,40 +144,28 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = credentials.get('audio_model_name', '__default') + audio_model_name = credentials.get("audio_model_name", "__default") for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] else: all_voices = [] for lang, lang_voices in voices.items(): all_voices.extend(lang_voices) return all_voices - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -194,8 +179,7 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -205,48 +189,42 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str :param voice: model timbre :return: text translated to audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] try: - api_key = credentials.get('api_key') - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + api_key = credentials.get("api_key") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} handle = RESTfulAudioModelHandle( - credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + credentials["model_uid"], credentials["server_url"], auth_headers=auth_headers ) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit( - handle.speech, - input=sentences[i], - voice=voice, - response_format="mp3", - speed=1.0, - stream=False - ) - for i in range(len(sentences))] + futures = [ + executor.submit( + handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): response = future.result() for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] else: response = handle.speech( - input=content_text.strip(), - voice=voice, - response_format="mp3", - speed=1.0, - stream=False + input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False ) for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 151166f165cb70..6ad10e690ddc17 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -18,9 +18,17 @@ class XinferenceModelExtraParameter: support_vision: bool = False model_family: Optional[str] - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, - model_family: Optional[str]) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + support_vision: bool, + max_tokens: int, + context_length: int, + model_family: Optional[str], + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -30,9 +38,11 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis self.context_length = context_length self.model_family = model_family + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: @@ -40,16 +50,16 @@ def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -58,55 +68,57 @@ def _clean_cache() -> None: @staticmethod def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) - headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}" + ) response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - model_family = response_json.get('model_family', None) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = response_json.get("model_ability", []) + model_family = response_json.get("model_family", None) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif response_json.get('model_type') == 'audio': - model_handle_type = 'audio' - if model_family and model_family in ['ChatTTS', 'CosyVoice', 'FishAudio']: - model_ability.append('text-to-audio') + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif response_json.get("model_type") == "audio": + model_handle_type = "audio" + if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + model_ability.append("text-to-audio") else: - model_ability.append('audio-to-text') - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + model_ability.append("audio-to-text") + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError('xinference model handle type is not supported') + raise NotImplementedError("xinference model handle type is not supported") - support_function_call = 'tools' in model_ability - support_vision = 'vision' in model_ability - max_tokens = response_json.get('max_tokens', 512) + support_function_call = "tools" in model_ability + support_vision = "vision" in model_ability + max_tokens = response_json.get("max_tokens", 512) - context_length = response_json.get('context_length', 2048) + context_length = response_json.get("context_length", 2048) return XinferenceModelExtraParameter( model_format=model_format, @@ -116,5 +128,5 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: st support_vision=support_vision, max_tokens=max_tokens, context_length=context_length, - model_family=model_family + model_family=model_family, ) diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333be9e7..5ab7fd126e3082 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -14,11 +14,17 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) # yi-vl-plus not support system prompt yet. @@ -27,7 +33,9 @@ def _invoke(self, model: str, credentials: dict, for message in prompt_messages: if not isinstance(message, SystemPromptMessage): prompt_message_except_system.append(message) - return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke( + model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream + ) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -36,8 +44,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -55,8 +62,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -76,10 +84,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -110,10 +118,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.lingyiwanwu.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/yi/yi.py b/api/core/model_runtime/model_providers/yi/yi.py index 691c7aa3711beb..9599acb22b505a 100644 --- a/api/core/model_runtime/model_providers/yi/yi.py +++ b/api/core/model_runtime/model_providers/yi/yi.py @@ -8,7 +8,6 @@ class YiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `yi-34b-chat-0205` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='yi-34b-chat-0205', - credentials=credentials - ) + model_instance.validate_credentials(model="yi-34b-chat-0205", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py index 6930a5ed0134b0..befc3de021e1f2 100644 --- a/api/core/model_runtime/model_providers/zhinao/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -7,11 +7,17 @@ class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.360.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.360.cn/v1" diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py index 44b36c9f51edd7..2a263292f98f14 100644 --- a/api/core/model_runtime/model_providers/zhinao/zhinao.py +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -8,7 +8,6 @@ class ZhinaoProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `360gpt-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='360gpt-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="360gpt-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 3412d8100f8e4c..fa95232f717d78 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials.get("zhipuai_api_key"), + "api_key": credentials["api_key"] if "api_key" in credentials else credentials.get("zhipuai_api_key"), } return credentials_kwargs @@ -38,5 +37,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index b2cdc7ad7a4644..484ac088dbc0c5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -35,12 +35,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,9 +67,9 @@ def _invoke(self, model: str, credentials: dict, # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -94,8 +99,13 @@ def _invoke(self, model: str, credentials: dict, # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -130,16 +140,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -154,15 +170,13 @@ def _generate(self, model: str, credentials_kwargs: dict, """ extra_model_kwargs = {} # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" - if stop and model!= 'glm-4v-plus': - extra_model_kwargs['stop'] = stop + if stop and model != "glm-4v-plus": + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: - raise ValueError('At least one message is required') + raise ValueError("At least one message is required") if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: @@ -175,10 +189,10 @@ def _generate(self, model: str, credentials_kwargs: dict, if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ('glm-4v', 'glm-4v-plus'): + if model not in ("glm-4v", "glm-4v-plus"): # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -188,8 +202,11 @@ def _generate(self, model: str, credentials_kwargs: dict, # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: if copy_prompt_message.role == PromptMessageRole.USER: @@ -208,77 +225,66 @@ def _generate(self, model: str, credentials_kwargs: dict, else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v' or model == 'glm-4v-plus': + if model == "glm-4v" or model == "glm-4v-plus": params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if ( + prompt_message.role == PromptMessageRole.SYSTEM + or prompt_message.role == PromptMessageRole.TOOL + or prompt_message.role == PromptMessageRole.USER + ): + if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": + params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append({"role": "user", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] + params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] if stream: response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) @@ -287,47 +293,41 @@ def _generate(self, model: str, credentials_kwargs: dict, response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], - model_parameters: dict): + def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict): messages = [ - { - 'role': message.role.value, - 'content': self._construct_glm_4v_messages(message.content) - } + {"role": message.role.value, "content": self._construct_glm_4v_messages(message.content)} for message in prompt_messages ] - params = { - 'model': model, - 'messages': messages, - **model_parameters - } + params = {"model": model, "messages": messages, **model_parameters} return params def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): - return [{'type': 'text', 'text': prompt_message}] + return [{"type": "text", "text": prompt_message}] return [ - {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} - if item.type == PromptMessageContentType.IMAGE else - {'type': 'text', 'text': item.data} - + {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}} + if item.type == PromptMessageContentType.IMAGE + else {"type": "text", "text": item.data} for item in prompt_message ] def _remove_image_header(self, image: str) -> str: - if image.startswith('data:image'): - return image.split(',')[1] + if image.startswith("data:image"): + return image.split(",")[1] return image - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -336,12 +336,12 @@ def _handle_generate_response(self, model: str, :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -349,11 +349,11 @@ def _handle_generate_response(self, model: str, function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' + text += choice.message.content or "" prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens @@ -365,20 +365,20 @@ def _handle_generate_response(self, model: str, result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls - ), + message=AssistantPromptMessage(content=text, tool_calls=assistant_tool_calls), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -387,19 +387,19 @@ def _handle_generate_stream_response(self, model: str, :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -407,17 +407,16 @@ def _handle_generate_stream_response(self, model: str, function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens @@ -429,24 +428,22 @@ def _handle_generate_stream_response(self, model: str, yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason - ) + index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -473,18 +470,16 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) if tools and len(tools) > 0: text += "\n\nTools:" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72e69c..ee20954381053d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -14,9 +14,9 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -27,16 +27,14 @@ def _invoke(self, model: str, credentials: dict, :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -50,7 +48,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int """ if len(texts) == 0: return 0 - + total_num_tokens = 0 for text in texts: total_num_tokens += self._get_num_tokens_by_gpt2(text) @@ -68,15 +66,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -100,7 +96,7 @@ def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tupl embedding_used_tokens += response.usage.total_tokens return [list(map(float, e)) for e in embeddings], embedding_used_tokens - + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. @@ -111,8 +107,8 @@ def embed_query(self, text: str) -> list[float]: Embeddings for the text. """ return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -122,10 +118,7 @@ def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> Emb """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -136,7 +129,7 @@ def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> Emb price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py index c517d2dba5a2d1..e75aad6eb0eb53 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -19,12 +19,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='glm-4', - credentials=credentials - ) + model_instance.validate_credentials(model="glm-4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 4dcd03f5511b6f..bf9b093cb3f02b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,4 +1,3 @@ - from .__version__ import __version__ from ._client import ZhipuAI from .core._errors import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332ca80af..659f38d7ff32d2 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1 @@ - -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = "v2.0.1" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 6588d1dd684900..df9e506095fab9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -20,14 +20,14 @@ class ZhipuAI(HttpClient): api_key: str def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if api_key is None: raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") @@ -38,6 +38,7 @@ def __init__( if base_url is None: base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ + super().__init__( version=__version__, base_url=base_url, @@ -58,9 +59,7 @@ def _auth_headers(self) -> dict[str, str]: return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): + if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): # if the '__init__' method raised an error, self would not have client attr return diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5fe979c..1f8011973951b3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,25 +17,24 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus @@ -57,9 +56,7 @@ def create( "tools": tools, "tool_choice": tool_choice, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) @@ -71,16 +68,11 @@ def retrieve_completion_result( disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index 5c4ed4d1ba5922..ec29f33864203c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -20,24 +20,24 @@ def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: _cast_type = Completion _stream_cls = StreamResponse[ChatCompletionChunk] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 35d54592fd55c0..2308a204514e17 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -18,16 +18,16 @@ def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: _cast_type = EmbeddingsResponded if disable_strict_validation: @@ -41,9 +41,7 @@ def create( "user": user, "sensitive_word_check": sensitive_word_check, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08f3405b..f2ac74bffa8439 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -17,17 +17,16 @@ class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: if not is_file_content(file): prefix = f"Expected file input `{file!r}`" @@ -44,21 +43,19 @@ def create( "purpose": purpose, }, files=files, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FileObject, ) def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca4567e3..dc30bd33edfbbc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index b860de192a612f..3d2e9208a11f17 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -16,21 +16,20 @@ class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( "/fine_tuning/jobs", @@ -42,34 +41,30 @@ def create( "validation_file": validation_file, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", @@ -93,7 +88,6 @@ def list_events( extra_headers: Headers | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: - return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 8eae1216d09e89..2692b093af8b43 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -18,21 +18,21 @@ def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: _cast_type = ImagesResponded if disable_strict_validation: @@ -50,11 +50,7 @@ def generations( "user": user, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index a2a438b8f3d355..1027c1bc5b1e55 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -17,7 +17,10 @@ class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: + def __init__( + self, + message: str, + ) -> None: super().__init__(message) @@ -31,24 +34,19 @@ def __init__(self, message: str, *, response: httpx.Response) -> None: self.status_code = response.status_code -class APIRequestFailedError(APIStatusError): - ... +class APIRequestFailedError(APIStatusError): ... -class APIAuthenticationError(APIStatusError): - ... +class APIAuthenticationError(APIStatusError): ... -class APIReachLimitError(APIStatusError): - ... +class APIReachLimitError(APIStatusError): ... -class APIInternalError(APIStatusError): - ... +class APIInternalError(APIStatusError): ... -class APIServerFlowExceedError(APIStatusError): - ... +class APIServerFlowExceedError(APIStatusError): ... class APIResponseError(Exception): @@ -67,16 +65,11 @@ class APIResponseValidationError(APIResponseError): status_code: int response: httpx.Response - def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None - ) -> None: + def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: super().__init__( message=message or "Data returned by API invalid for expected schema.", request=response.request, - json_data=json_data + json_data=json_data, ) self.response = response self.status_code = response.status_code diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 65401f6c1cd03f..48eeb37c41ba7d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -48,13 +48,13 @@ class HttpClient: _default_stream_cls: type[StreamResponse[Any]] | None = None def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, + self, + *, + version: str, + base_url: URL, + timeout: Union[float, Timeout, None], + custom_httpx_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if timeout is None or isinstance(timeout, NotGiven): if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: @@ -76,7 +76,6 @@ def __init__( self._custom_headers = custom_headers or {} def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) if sub_url.is_relative_url: request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") @@ -86,16 +85,15 @@ def _prepare_url(self, url: str) -> URL: @property def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } + return { + "Accept": "application/json", + "Content-Type": "application/json; charset=UTF-8", + "ZhipuAI-SDK-Ver": self._version, + "source_type": "zhipu-sdk-python", + "x-request-sdk": "zhipu-sdk-python", + **self._auth_headers, + **self._custom_headers, + } @property def _auth_headers(self): @@ -109,10 +107,7 @@ def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers: return httpx_headers - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: + def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: kwargs: dict[str, Any] = {} json_data = request_param.json_data headers = self._prepare_headers(request_param) @@ -164,7 +159,6 @@ def _primitive_value_to_str(val) -> str: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) serialized: dict[str, object] = {} @@ -175,30 +169,25 @@ def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object return serialized def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + enable_stream: bool, + request_param: ClientRequestParam, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> HttpResponse: - http_response = HttpResponse( - raw_response=response, - cast_type=cast_type, - client=self, - enable_stream=enable_stream, - stream_cls=stream_cls + raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls ) return http_response.parse() def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, + self, + *, + data: object, + cast_type: type[ResponseT], + response: httpx.Response, ) -> ResponseT: if data is None: return cast(ResponseT, None) @@ -225,12 +214,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + params: ClientRequestParam, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: request = self._prepare_request(params) @@ -259,81 +248,79 @@ def request( ) def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + enable_stream: bool = False, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream - ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream) def post( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) - - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream, - stream_cls=stream_cls + opts = ClientRequestParam.construct( + method="post", json_data=body, files=make_httpx_files(files), url=path, **options ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls) + def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) + opts = ClientRequestParam.construct( + method="put", url=path, json_data=body, files=make_httpx_files(files), **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def _make_status_error(self, response) -> APIStatusError: @@ -355,11 +342,11 @@ def _make_status_error(self, response) -> APIStatusError: def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, - query: Query | None = None, + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers = None, + extra_body: Body | None = None, + query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -368,7 +355,7 @@ def make_user_request_input( if max_retries is not None: options["max_retries"] = max_retries if not isinstance(timeout, NotGiven): - options['timeout'] = timeout + options["timeout"] = timeout if query is not None: options["params"] = query if extra_body is not None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba8461e03..ac459151fc3a42 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -35,17 +35,14 @@ def get_max_retries(self, max_retries) -> int: @classmethod def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : - kwargs: dict[str, Any] = { - key: remove_notgiven_indict(value) for key, value in values.items() - } + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: + kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()} client = cls() client.__dict__.update(kwargs) return client model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 2f831b6fc9ca73..56e60a793407cd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -26,13 +26,13 @@ class HttpResponse(Generic[R]): http_response: httpx.Response def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: HttpClient, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: self._cast_type = cast_type self._client = client @@ -52,8 +52,8 @@ def _parse(self) -> R: self._stream_cls( cast_type=cast(type, get_args(self._stream_cls)[0]), response=self.http_response, - client=self._client - ) + client=self._client, + ), ) return self._parsed cast_type = self._cast_type diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 66afbfd10780cc..3566c6b332bfce 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -16,16 +16,15 @@ class StreamResponse(Generic[ResponseT]): - response: httpx.Response _cast_type: type[ResponseT] def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, ) -> None: self.response = response self._cast_type = cast_type @@ -39,7 +38,6 @@ def __iter__(self) -> Iterator[ResponseT]: yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() iterator = sse_line_parser.iter_lines(self.response.iter_lines()) @@ -63,11 +61,7 @@ def __stream__(self) -> Iterator[ResponseT]: class Event: def __init__( - self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None + self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None ): self._event = event self._data = data @@ -76,21 +70,28 @@ def __init__( def __repr__(self): data_len = len(self._data) if self._data else 0 - return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + return ( + f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + ) @property - def event(self): return self._event + def event(self): + return self._event @property - def data(self): return self._data + def data(self): + return self._data - def json_data(self): return json.loads(self._data) + def json_data(self): + return json.loads(self._data) @property - def id(self): return self._id + def id(self): + return self._id @property - def retry(self): return self._retry + def retry(self): + return self._retry class SSELineParser: @@ -107,19 +108,11 @@ def __init__(self): def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: for line in lines: - line = line.rstrip('\n') + line = line.rstrip("\n") if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: + if self._event is None and not self._data and self._id is None and self._retry is None: continue - sse_event = Event( - event=self._event, - data='\n'.join(self._data), - id=self._id, - retry=self._retry - ) + sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) self._event = None self._data = [] self._id = None @@ -134,7 +127,7 @@ def decode_line(self, line: str): field, _p, value = line.partition(":") - if value.startswith(' '): + if value.startswith(" "): value = value[1:] if field == "data": self._data.append(value) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d25120f0..a0645b09168821 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c50c357d..4b3a929a2b816d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 917bda75767b9d..75f76fe969faf7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -6,7 +6,6 @@ class FileObject(BaseModel): - id: Optional[str] = None bytes: Optional[int] = None created_at: Optional[int] = None @@ -18,7 +17,6 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): - object: Optional[str] = None data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaff0dd18..1d3930286b89d3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index fe705d6943a447..e4f35414756400 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -4,9 +4,9 @@ class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: @@ -36,8 +36,9 @@ def _validate_and_filter_credential_form_schemas(self, return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -49,7 +50,7 @@ def _validate_credential_form_schema(self, credential_form_schema: CredentialFor if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError(f"Variable {credential_form_schema.variable} is required") else: # Get the value of default if credential_form_schema.default: @@ -65,23 +66,25 @@ def _validate_credential_form_schema(self, credential_form_schema: CredentialFor # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError(f"Variable {credential_form_schema.variable} should be string") if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + raise ValueError(f"Variable {credential_form_schema.variable} is not in options") if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in ["true", "false"]: + raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index c4786fad5d4c08..7d1644d13481b1 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -4,7 +4,6 @@ class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): self.model_type = model_type self.model_credential_schema = model_credential_schema diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index c945016534ed8a..6dff2428ca0c34 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -3,7 +3,6 @@ class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 5078f00bfa26d0..ec1bad5698f2eb 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,11 +18,10 @@ from pydantic_extra_types.color import Color -def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any -) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) + # Taken from Pydantic v1 as is def isoformat(o: Union[datetime.date, datetime.time]) -> str: return o.isoformat() @@ -82,11 +81,9 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( - tuple - ) + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -149,17 +146,13 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) for key, value in obj.items(): if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (value is not None or not exclude_none) and key in allowed_keys ): diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index c68a554471703f..2067092d80f582 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index f96e2a1c214f5d..094ad7863603dc 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -44,32 +44,29 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - params = ModerationInputParams( - app_id=self.app_id, - inputs=inputs, - query=query - ) + if self.config["inputs_config"]["enabled"]: + params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - params = ModerationOutputParams( - app_id=self.app_id, - text=text - ) + if self.config["outputs_config"]["enabled"]: + params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) @@ -80,9 +77,10 @@ def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, para @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 757dd2ab46567f..4b91f20184b5e9 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,8 +8,8 @@ class ModerationAction(Enum): - DIRECT_OUTPUT = 'direct_output' - OVERRIDDEN = 'overridden' + DIRECT_OUTPUT = "direct_output" + OVERRIDDEN = "overridden" class ModerationInputsResult(BaseModel): @@ -31,6 +31,7 @@ class Moderation(Extensible, ABC): """ The base class of moderation. """ + module: ExtensionModule = ExtensionModule.MODERATION def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46dfacbc9e0d83..336c16eecf784e 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -13,13 +13,14 @@ class InputModeration: def check( - self, app_id: str, + self, + app_id: str, tenant_id: str, app_config: AppConfig, inputs: dict, query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -39,10 +40,7 @@ def check( moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( - name=moderation_type, - app_id=app_id, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config.config + name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) with measure_time() as timer: @@ -55,7 +53,7 @@ def check( message_id=message_id, moderation_result=moderation_result, inputs=inputs, - timer=timer + timer=timer, ) ) diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index ca562ad987cada..17e48b8fbe5769 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -25,31 +25,35 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: + if self.config["outputs_config"]["enabled"]: # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated({'text': text}, keywords_list) - preset_response = self.config['outputs_config']['preset_response'] + flagged = self._is_violated({"text": text}, keywords_list) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index fee51007ebeed7..6465de23b9a2de 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -21,37 +21,36 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - flagged = self._is_violated({'text': text}) - preset_response = self.config['outputs_config']['preset_response'] + if self.config["outputs_config"]["enabled"]: + flagged = self._is_violated({"text": text}) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict): - text = '\n'.join(str(inputs.values())) + text = "\n".join(str(inputs.values())) model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider="openai", - model_type=ModelType.MODERATION, - model="text-moderation-stable" + tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation( - text=text - ) + openai_moderation = model_instance.invoke_moderation(text=text) return openai_moderation diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 69e28770c3415f..d8d794be18f21e 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -29,7 +29,7 @@ class OutputModeration(BaseModel): thread: Optional[threading.Thread] = None thread_running: bool = True - buffer: str = '' + buffer: str = "" is_final_chunk: bool = False final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -50,11 +50,7 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> self.buffer = completion self.is_final_chunk = True - result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=completion - ) + result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: return completion @@ -65,21 +61,19 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> final_output = result.text if public_event: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) return final_output def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE - thread = threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE - }) + thread = threading.Thread( + target=self.worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, + }, + ) thread.start() @@ -104,9 +98,7 @@ def worker(self, flask_app: Flask, buffer_size: int): current_length = buffer_length result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=moderation_buffer + tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) if not result or not result.flagged: @@ -116,16 +108,11 @@ def worker(self, flask_app: Flask, buffer_size: int): final_output = result.preset_response self.final_output = final_output else: - final_output = result.text + self.buffer[len(moderation_buffer):] + final_output = result.text + self.buffer[len(moderation_buffer) :] # trigger replace event if self.thread_running: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -133,10 +120,7 @@ def worker(self, flask_app: Flask, buffer_size: int): def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: try: moderation_factory = ModerationFactory( - name=self.rule.type, - app_id=app_id, - tenant_id=tenant_id, - config=self.rule.config + name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index c7af8e296339c8..f7b882fc71d48e 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -23,4 +23,4 @@ def trace(self, trace_info: BaseTraceInfo): Abstract method to trace activities. Subclasses must implement specific tracing logic for activities. """ - ... \ No newline at end of file + ... diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab9302..0ab2139a8888c8 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -4,14 +4,15 @@ class TracingProviderEnum(Enum): - LANGFUSE = 'langfuse' - LANGSMITH = 'langsmith' + LANGFUSE = "langfuse" + LANGSMITH = "langsmith" class BaseTracingConfig(BaseModel): """ Base model class for tracing """ + ... @@ -19,16 +20,17 @@ class LangfuseConfig(BaseTracingConfig): """ Model class for Langfuse tracing config. """ + public_key: str secret_key: str - host: str = 'https://api.langfuse.com' + host: str = "https://api.langfuse.com" @field_validator("host") def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.langfuse.com' - if not v.startswith('https://') and not v.startswith('http://'): - raise ValueError('host must start with https:// or http://') + v = "https://api.langfuse.com" + if not v.startswith("https://") and not v.startswith("http://"): + raise ValueError("host must start with https:// or http://") return v @@ -37,15 +39,16 @@ class LangSmithConfig(BaseTracingConfig): """ Model class for Langsmith tracing config. """ + api_key: str project: str - endpoint: str = 'https://api.smith.langchain.com' + endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.smith.langchain.com' - if not v.startswith('https://'): - raise ValueError('endpoint must start with https://') + v = "https://api.smith.langchain.com" + if not v.startswith("https://"): + raise ValueError("endpoint must start with https://") return v diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a1443f0691233b..a3ce27d5d44ec4 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -23,6 +23,7 @@ def ensure_type(cls, v): else: return "" + class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any conversation_id: Optional[str] = None @@ -98,23 +99,24 @@ class GenerateNameTraceInfo(BaseTraceInfo): conversation_id: Optional[str] = None tenant_id: str + trace_info_info_map = { - 'WorkflowTraceInfo': WorkflowTraceInfo, - 'MessageTraceInfo': MessageTraceInfo, - 'ModerationTraceInfo': ModerationTraceInfo, - 'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo, - 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, - 'ToolTraceInfo': ToolTraceInfo, - 'GenerateNameTraceInfo': GenerateNameTraceInfo, + "WorkflowTraceInfo": WorkflowTraceInfo, + "MessageTraceInfo": MessageTraceInfo, + "ModerationTraceInfo": ModerationTraceInfo, + "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo, + "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, + "ToolTraceInfo": ToolTraceInfo, + "GenerateNameTraceInfo": GenerateNameTraceInfo, } class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation' - WORKFLOW_TRACE = 'workflow' - MESSAGE_TRACE = 'message' - MODERATION_TRACE = 'moderation' - SUGGESTED_QUESTION_TRACE = 'suggested_question' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' - TOOL_TRACE = 'tool' - GENERATE_NAME_TRACE = 'generate_conversation_name' + CONVERSATION_TRACE = "conversation" + WORKFLOW_TRACE = "workflow" + MESSAGE_TRACE = "message" + MODERATION_TRACE = "moderation" + SUGGESTED_QUESTION_TRACE = "suggested_question" + DATASET_RETRIEVAL_TRACE = "dataset_retrieval" + TOOL_TRACE = "tool" + GENERATE_NAME_TRACE = "generate_conversation_name" diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index f3fc46d99a8692..8cbf162bf29e91 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -35,38 +35,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field( - None, description="Serialized data of the run" - ) + serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field( - None, description="Session ID associated with the run" - ) - session_name: Optional[str] = Field( - None, description="Session name associated with the run" - ) - reference_example_id: Optional[str] = Field( - None, description="Reference example ID associated with the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + session_id: Optional[str] = Field(None, description="Session ID associated with the run") + session_name: Optional[str] = Field(None, description="Session name associated with the run") + reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") def ensure_dict(cls, v, info: ValidationInfo): @@ -75,9 +57,9 @@ def ensure_dict(cls, v, info: ValidationInfo): if v == {} or v is None: return v usage_metadata = { - "input_tokens": values.get('input_tokens', 0), - "output_tokens": values.get('output_tokens', 0), - "total_tokens": values.get('total_tokens', 0), + "input_tokens": values.get("input_tokens", 0), + "output_tokens": values.get("output_tokens", 0), + "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) if isinstance(v, str): @@ -143,25 +125,15 @@ def format_time(cls, v, info: ValidationInfo): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") end_time: Optional[datetime | str] = Field(None, description="End time of the run") error: Optional[str] = Field(None, description="Error message of the run") inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9cbc805fe7dc31..eea7bb3535fe8c 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -159,8 +159,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): run_type = LangSmithRunType.llm metadata.update( { - 'ls_provider': process_data.get('model_provider', ''), - 'ls_model_name': process_data.get('model_name', ''), + "ls_provider": process_data.get("model_provider", ""), + "ls_model_name": process_data.get("model_name", ""), } ) elif node_type == "knowledge-retrieval": @@ -385,12 +385,10 @@ def get_project_url(self): start_time=datetime.now(), ) - project_url = self.langsmith_client.get_run_url(run=run_data, - project_id=self.project_id, - project_name=self.project_name) - return project_url.split('/r/')[0] + project_url = self.langsmith_client.get_run_url( + run=run_data, project_id=self.project_id, project_name=self.project_name + ) + return project_url.split("/r/")[0] except Exception as e: logger.debug(f"LangSmith get run url failed: {str(e)}") raise ValueError(f"LangSmith get run url failed: {str(e)}") - - diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index aefab6ed16a1d8..d6156e479af163 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -36,17 +36,17 @@ provider_config_map = { TracingProviderEnum.LANGFUSE.value: { - 'config_class': LangfuseConfig, - 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host', 'project_key'], - 'trace_instance': LangFuseDataTrace + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, }, TracingProviderEnum.LANGSMITH.value: { - 'config_class': LangSmithConfig, - 'secret_keys': ['api_key'], - 'other_keys': ['project', 'endpoint'], - 'trace_instance': LangSmithDataTrace - } + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + }, } @@ -64,14 +64,17 @@ def encrypt_tracing_config( :return: encrypted tracing configuration """ # Get the configuration class and the keys that require encryption - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} # Encrypt necessary keys for key in secret_keys: if key in tracing_config: - if '*' in tracing_config[key]: + if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config new_config[key] = current_trace_config.get(key, tracing_config[key]) else: @@ -94,8 +97,11 @@ def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_c :param tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in tracing_config: @@ -114,8 +120,11 @@ def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: :param decrypt_tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in decrypt_tracing_config: @@ -133,9 +142,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None @@ -164,21 +175,21 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter( - App.id == app_id - ).first() + app: App = db.session.query(App).filter(App.id == app_id).first() app_ops_trace_config = json.loads(app.tracing) if app.tracing else None if app_ops_trace_config is not None: - tracing_provider = app_ops_trace_config.get('tracing_provider') + tracing_provider = app_ops_trace_config.get("tracing_provider") else: return None # decrypt_token decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) - if app_ops_trace_config.get('enabled'): - trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \ - provider_config_map[tracing_provider]['config_class'] + if app_ops_trace_config.get("enabled"): + trace_instance, config_class = ( + provider_config_map[tracing_provider]["trace_instance"], + provider_config_map[tracing_provider]["config_class"], + ) tracing_instance = trace_instance(config_class(**decrypt_trace_config)) return tracing_instance @@ -192,9 +203,11 @@ def get_app_config_through_message_id(cls, message_id: str): conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if conversation_data.app_model_config_id: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation_data.app_model_config_id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .first() + ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs @@ -231,10 +244,7 @@ def get_app_tracing_config(cls, app_id: str): """ app: App = db.session.query(App).filter(App.id == app_id).first() if not app.tracing: - return { - "enabled": False, - "tracing_provider": None - } + return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) return app_trace_config @@ -246,8 +256,10 @@ def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str) :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() @@ -259,8 +271,10 @@ def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_key() @@ -272,8 +286,10 @@ def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_url() @@ -287,7 +303,7 @@ def __init__( conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, - **kwargs + **kwargs, ): self.trace_type = trace_type self.message_id = message_id @@ -310,9 +326,7 @@ def preprocess(self): self.workflow_run, self.conversation_id, self.user_id ), TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( - self.message_id, self.timer, **self.kwargs - ), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs ), @@ -337,12 +351,8 @@ def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = ( - json.loads(workflow_run.inputs) if workflow_run.inputs else {} - ) - workflow_run_outputs = ( - json.loads(workflow_run.outputs) if workflow_run.outputs else {} - ) + workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} + workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_version = workflow_run.version error = workflow_run.error if workflow_run.error else "" @@ -352,17 +362,18 @@ def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - tenant_id=tenant_id, - app_id=workflow_run.app_id, - workflow_run_id=workflow_run.id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog) + .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) + .first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None # get message_id - message_data = db.session.query(Message.id).filter_by( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id - ).first() + message_data = ( + db.session.query(Message.id) + .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) + .first() + ) message_id = str(message_data.id) if message_data else None metadata = { @@ -470,9 +481,9 @@ def moderation_trace(self, message_id, timer, **kwargs): # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( @@ -510,9 +521,9 @@ def suggested_question_trace(self, message_id, timer, **kwargs): # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( @@ -569,9 +580,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get('tool_name') - tool_inputs = kwargs.get('tool_inputs') - tool_outputs = kwargs.get('tool_outputs') + tool_name = kwargs.get("tool_name") + tool_inputs = kwargs.get("tool_inputs") + tool_outputs = kwargs.get("tool_outputs") message_data = get_message_data(message_id) if not message_data: return {} @@ -586,11 +597,11 @@ def tool_trace(self, message_id, timer, **kwargs): if tool_name in agent_thought.tools: created_time = agent_thought.created_at tool_meta_data = agent_thought.tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - time_cost = tool_meta_data.get('time_cost', 0) + tool_config = tool_meta_data.get("tool_config", {}) + time_cost = tool_meta_data.get("time_cost", 0) end_time = created_time + timedelta(seconds=time_cost) - error = tool_meta_data.get('error', "") - tool_parameters = tool_meta_data.get('tool_parameters', {}) + error = tool_meta_data.get("error", "") + tool_parameters = tool_meta_data.get("tool_parameters", {}) metadata = { "message_id": message_id, "tool_name": tool_name, @@ -715,9 +726,7 @@ def run(self): def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run - ) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 3b2e04abb73288..498685b3426d37 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -20,19 +20,19 @@ def get_message_data(message_id): @contextmanager def measure_time(): - timing_info = {'start': datetime.now(), 'end': None} + timing_info = {"start": datetime.now(), "end": None} try: yield timing_info finally: - timing_info['end'] = datetime.now() + timing_info["end"] = datetime.now() def replace_text_with_content(data): if isinstance(data, dict): new_data = {} for key, value in data.items(): - if key == 'text': - new_data['content'] = value + if key == "text": + new_data["content"] = value else: new_data[key] = replace_text_with_content(value) return new_data diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 22420fea2cc02f..ce8038d14e2741 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ + def __init__(self, with_variable_tmpl: bool = False) -> None: self.with_variable_tmpl = with_variable_tmpl - def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, - query: str, - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def get_prompt( + self, + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], + inputs: dict, + query: str, + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: inputs = {key: str(value) for key, value in inputs.items()} prompt_messages = [] @@ -48,7 +52,7 @@ def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionMo context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( @@ -60,20 +64,22 @@ def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionMo context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _get_completion_model_prompt_messages(self, - prompt_template: CompletionModelPromptTemplate, - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _get_completion_model_prompt_messages( + self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get completion model prompt messages. """ @@ -81,7 +87,7 @@ def _get_completion_model_prompt_messages(self, prompt_messages = [] - if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: + if prompt_template.edition_type == "basic" or not prompt_template.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -96,15 +102,13 @@ def _get_completion_model_prompt_messages(self, role_prefix=role_prefix, prompt_template=prompt_template, prompt_inputs=prompt_inputs, - model_config=model_config + model_config=model_config, ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) + prompt = prompt_template.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -122,16 +126,18 @@ def _get_completion_model_prompt_messages(self, return prompt_messages - def _get_chat_model_prompt_messages(self, - prompt_template: list[ChatModelMessage], - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def _get_chat_model_prompt_messages( + self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: """ Get chat model prompt messages. """ @@ -142,22 +148,20 @@ def _get_chat_model_prompt_messages(self, for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: + if prompt_item.edition_type == "basic" or not prompt_item.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) - elif prompt_item.edition_type == 'jinja2': + prompt = prompt_template.format(prompt_inputs) + elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs prompt = Jinja2Formatter.format(prompt, prompt_inputs) else: - raise ValueError(f'Invalid edition type: {prompt_item.edition_type}') + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") if prompt_item.role == PromptMessageRole.USER: prompt_messages.append(UserPromptMessage(content=prompt)) @@ -168,17 +172,14 @@ def _get_chat_model_prompt_messages(self, if query and query_prompt_template: prompt_template = PromptTemplateParser( - template=query_prompt_template, - with_variable_tmpl=self.with_variable_tmpl + template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs['#sys.query#'] = query + prompt_inputs["#sys.query#"] = query prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - query = prompt_template.format( - prompt_inputs - ) + query = prompt_template.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) @@ -203,7 +204,7 @@ def _get_chat_model_prompt_messages(self, last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -220,38 +221,39 @@ def _get_chat_model_prompt_messages(self, return prompt_messages def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#context#' in prompt_template.variable_keys: + if "#context#" in prompt_template.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs["#context#"] = context else: - prompt_inputs['#context#'] = '' + prompt_inputs["#context#"] = "" return prompt_inputs def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#query#' in prompt_template.variable_keys: + if "#query#" in prompt_template.variable_keys: if query: - prompt_inputs['#query#'] = query + prompt_inputs["#query#"] = query else: - prompt_inputs['#query#'] = '' + prompt_inputs["#query#"] = "" return prompt_inputs - def _set_histories_variable(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - raw_prompt: str, - role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigWithCredentialsEntity) -> dict: - if '#histories#' in prompt_template.variable_keys: + def _set_histories_variable( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity, + ) -> dict: + if "#histories#" in prompt_template.variable_keys: if memory: - inputs = {'#histories#': '', **prompt_inputs} + inputs = {"#histories#": "", **prompt_inputs} prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage( - content=prompt_template.format(prompt_inputs) - ) + tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) @@ -260,10 +262,10 @@ def _set_histories_variable(self, memory: TokenBufferMemory, memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant + ai_prefix=role_prefix.assistant, ) - prompt_inputs['#histories#'] = histories + prompt_inputs["#histories#"] = histories else: - prompt_inputs['#histories#'] = '' + prompt_inputs["#histories#"] = "" return prompt_inputs diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea9154fc..caa1793ea8c039 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ - def __init__(self, - model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage], - history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, - ): + + def __init__( + self, + model_config: ModelConfigWithCredentialsEntity, + prompt_messages: list[PromptMessage], + history_messages: list[PromptMessage], + memory: Optional[TokenBufferMemory] = None, + ): self.model_config = model_config self.prompt_messages = prompt_messages self.history_messages = history_messages @@ -45,9 +47,7 @@ def get_prompt(self) -> list[PromptMessage]: model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - self.history_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,9 +63,7 @@ def get_prompt(self) -> list[PromptMessage]: # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - prompt_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 61df69163cba1c..c8e7b414dffe85 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel): """ Chat Message. """ + text: str role: PromptMessageRole - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class CompletionModelPromptTemplate(BaseModel): """ Completion Model Prompt Template. """ + text: str - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class MemoryConfig(BaseModel): """ Memory Config. """ + class RolePrefix(BaseModel): """ Role Prefix. """ + user: str assistant: str @@ -37,6 +41,7 @@ class WindowConfig(BaseModel): """ Window Config. """ + enabled: bool size: Optional[int] = None diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index da40534d99485b..e4b3a61cb4c001 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -7,39 +7,18 @@ "prompt": { "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " }, - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - } + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, - "stop": ["Human:"] + "stop": ["Human:"], } -CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } -} +CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}} -COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } -} +COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}} COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["Human:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["Human:"], } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { @@ -47,37 +26,20 @@ "prompt": { "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" }, - "conversation_histories_role": { - "user_prefix": "用户", - "assistant_prefix": "助手" - } + "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, }, - "stop": ["用户:"] + "stop": ["用户:"], } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } + "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["用户:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["用户:"], } diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index b86d3fa815d2ce..87acdb3c49cc01 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -9,75 +9,78 @@ class PromptTransform: - def _append_chat_histories(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _append_chat_histories( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: + def _get_history_messages_from_memory( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None, + ) -> str: """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } + kwargs = {"max_token_limit": max_token_limit} if human_prefix: - kwargs['human_prefix'] = human_prefix + kwargs["human_prefix"] = human_prefix if ai_prefix: - kwargs['ai_prefix'] = ai_prefix + kwargs["ai_prefix"] = ai_prefix if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: - kwargs['message_limit'] = memory_config.window.size + kwargs["message_limit"] = memory_config.window.size - return memory.get_history_prompt_text( - **kwargs - ) + return memory.get_history_prompt_text(**kwargs) - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int) -> list[PromptMessage]: + def _get_history_messages_list_from_memory( + self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int + ) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit, message_limit=memory_config.window.size - if (memory_config.window.enabled - and memory_config.window.size is not None - and memory_config.window.size > 0) - else None + if ( + memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + ) + else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fd7ed0181be2f2..13e5c5253e4d4b 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -22,11 +22,11 @@ class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' + COMPLETION = "completion" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'ModelMode': + def value_of(cls, value: str) -> "ModelMode": """ Get value of given mode. @@ -36,7 +36,7 @@ def value_of(cls, value: str) -> 'ModelMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") prompt_file_contents = {} @@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform): Simple Prompt Transform for Chatbot App Basic Mode. """ - def get_prompt(self, - app_mode: AppMode, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: + def get_prompt( + self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode.value_of(model_config.mode) @@ -69,7 +70,7 @@ def get_prompt(self, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -80,19 +81,21 @@ def get_prompt(self, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages, stops - def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigWithCredentialsEntity, - pre_prompt: str, - inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, - ) -> tuple[str, dict]: + def get_prompt_str_and_rules( + self, + app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -101,74 +104,75 @@ def get_prompt_str_and_rules(self, app_mode: AppMode, pre_prompt=pre_prompt, has_context=context is not None, query_in_prompt=query is not None, - with_memory_prompt=histories is not None + with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} - for v in prompt_template_config['special_variable_keys']: + for v in prompt_template_config["special_variable_keys"]: # support #context#, #query# and #histories# - if v == '#context#': - variables['#context#'] = context if context else '' - elif v == '#query#': - variables['#query#'] = query if query else '' - elif v == '#histories#': - variables['#histories#'] = histories if histories else '' - - prompt_template = prompt_template_config['prompt_template'] + if v == "#context#": + variables["#context#"] = context if context else "" + elif v == "#query#": + variables["#query#"] = query if query else "" + elif v == "#histories#": + variables["#histories#"] = histories if histories else "" + + prompt_template = prompt_template_config["prompt_template"] prompt = prompt_template.format(variables) - return prompt, prompt_template_config['prompt_rules'] + return prompt, prompt_template_config["prompt_rules"] - def get_prompt_template(self, app_mode: AppMode, - provider: str, - model: str, - pre_prompt: str, - has_context: bool, - query_in_prompt: bool, - with_memory_prompt: bool = False) -> dict: - prompt_rules = self._get_prompt_rule( - app_mode=app_mode, - provider=provider, - model=model - ) + def get_prompt_template( + self, + app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False, + ) -> dict: + prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] special_variable_keys = [] - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt' and has_context: - prompt += prompt_rules['context_prompt'] - special_variable_keys.append('#context#') - elif order == 'pre_prompt' and pre_prompt: - prompt += pre_prompt + '\n' + prompt = "" + for order in prompt_rules["system_prompt_orders"]: + if order == "context_prompt" and has_context: + prompt += prompt_rules["context_prompt"] + special_variable_keys.append("#context#") + elif order == "pre_prompt" and pre_prompt: + prompt += pre_prompt + "\n" pre_prompt_template = PromptTemplateParser(template=pre_prompt) custom_variable_keys = pre_prompt_template.variable_keys - elif order == 'histories_prompt' and with_memory_prompt: - prompt += prompt_rules['histories_prompt'] - special_variable_keys.append('#histories#') + elif order == "histories_prompt" and with_memory_prompt: + prompt += prompt_rules["histories_prompt"] + special_variable_keys.append("#histories#") if query_in_prompt: - prompt += prompt_rules.get('query_prompt', '{{#query#}}') - special_variable_keys.append('#query#') + prompt += prompt_rules.get("query_prompt", "{{#query#}}") + special_variable_keys.append("#query#") return { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, - "prompt_rules": prompt_rules + "prompt_rules": prompt_rules, } - def _get_chat_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_chat_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -178,7 +182,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, pre_prompt=pre_prompt, inputs=inputs, query=None, - context=context + context=context, ) if prompt and query: @@ -193,7 +197,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, ) ), prompt_messages=prompt_messages, - model_config=model_config + model_config=model_config, ) if query: @@ -203,15 +207,17 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, return prompt_messages, None - def _get_completion_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_completion_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=app_mode, @@ -219,13 +225,11 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, pre_prompt=pre_prompt, inputs=inputs, query=query, - context=context + context=context, ) if memory: - tmp_human_message = UserPromptMessage( - content=prompt - ) + tmp_human_message = UserPromptMessage(content=prompt) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( @@ -236,8 +240,8 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, ) ), max_token_limit=rest_tokens, - human_prefix=prompt_rules.get('human_prefix', 'Human'), - ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), ) # get prompt @@ -248,10 +252,10 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, inputs=inputs, query=query, context=context, - histories=histories + histories=histories, ) - stops = prompt_rules.get('stops') + stops = prompt_rules.get("stops") if stops is not None and len(stops) == 0: stops = None @@ -277,22 +281,18 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict :param model: model name :return: """ - prompt_file_name = self._prompt_file_name( - app_mode=app_mode, - provider=provider, - model=model - ) + prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model) # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') - json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") + json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json") # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: + with open(json_file_path, encoding="utf-8") as json_file: content = json.load(json_file) # Store the content of the prompt file @@ -303,21 +303,21 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False - if provider == 'baichuan': + if provider == "baichuan": is_baichuan = True else: baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + if provider in baichuan_supported_providers and "baichuan" in model.lower(): is_baichuan = True if is_baichuan: if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' + return "baichuan_completion" else: - return 'baichuan_chat' + return "baichuan_chat" # common if app_mode == AppMode.COMPLETION: - return 'common_completion' + return "common_completion" else: - return 'common_chat' + return "common_chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index befdceeda505fb..29494db2214729 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -25,26 +25,29 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[ tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: - role = 'user' + role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' + role = "assistant" if isinstance(prompt_message, AssistantPromptMessage): - tool_calls = [{ - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls] + for tool_call in prompt_message.tool_calls + ] elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' + role = "system" elif prompt_message.role == PromptMessageRole.TOOL: - role = 'tool' + role = "tool" else: continue - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -53,27 +56,25 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[ text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content - prompt = { - "role": role, - "text": text, - "files": files - } - + prompt = {"role": role, "text": text, "files": files} + if tool_calls: - prompt['tool_calls'] = tool_calls + prompt["tool_calls"] = tool_calls prompts.append(prompt) else: prompt_message = prompt_messages[0] - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -82,21 +83,23 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[ text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content params = { - "role": 'user', + "role": "user", "text": text, } if files: - params['files'] = files + params["files"] = files prompts.append(params) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 3e68492df2f2d7..81115596757493 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -38,8 +38,8 @@ def replacer(match): return value prompt = re.sub(self.regex, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): - return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 67eee2c2944e7a..3a1fe300dfd311 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -90,8 +90,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict + tenant_id, provider_name_to_provider_records_dict ) # Get all provider model records of the workspace @@ -107,22 +106,20 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict \ - = self._get_all_provider_load_balancing_configs(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + tenant_id ) + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: - # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, - data=provider_entity, - name_func=lambda x: x.provider, + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, ): continue @@ -132,18 +129,11 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: # Convert to custom configuration custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records + tenant_id, provider_entity, provider_records, provider_model_records ) # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) @@ -173,14 +163,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) # Get provider load balancing configs - provider_load_balancing_configs \ - = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs + load_balancing_model_configs=provider_load_balancing_configs, ) provider_configuration = ProviderConfiguration( @@ -190,7 +181,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: using_provider_type=using_provider_type, system_configuration=system_configuration, custom_configuration=custom_configuration, - model_settings=model_settings + model_settings=model_settings, ) provider_configurations[provider_name] = provider_configuration @@ -219,7 +210,7 @@ def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: M return ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -231,11 +222,14 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D :return: """ # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -244,20 +238,18 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next((model for model in available_models if model.model == "gpt-4"), - available_models[0]) + available_model = next( + (model for model in available_models if model.model == "gpt-4"), available_models[0] + ) default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), provider_name=available_model.provider.provider, - model_name=available_model.model + model_name=available_model.model, ) db.session.add(default_model) db.session.commit() @@ -276,8 +268,8 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: @@ -291,15 +283,13 @@ def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - all_models = provider_configurations.get_models( - model_type=model_type, - only_active=False - ) + all_models = provider_configurations.get_models(model_type=model_type, only_active=False) return all_models[0].provider.provider, all_models[0].model - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: + def update_default_model_record( + self, tenant_id: str, model_type: ModelType, provider: str, model: str + ) -> TenantDefaultModel: """ Update default model record. @@ -314,10 +304,7 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro raise ValueError(f"Provider {provider} does not exist.") # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) # check if the model is exist in available models model_names = [model.model for model in available_models] @@ -325,11 +312,14 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro raise ValueError(f"Model {model} does not exist.") # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # create or update TenantDefaultModel record if default_model: @@ -358,11 +348,7 @@ def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -379,11 +365,11 @@ def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: :return: """ # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() + provider_models = ( + db.session.query(ProviderModel) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + .all() + ) provider_name_to_provider_model_records_dict = defaultdict(list) for provider_model in provider_models: @@ -399,10 +385,11 @@ def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPrefer :param tenant_id: workspace id :return: """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() + preferred_provider_types = ( + db.session.query(TenantPreferredModelProvider) + .filter(TenantPreferredModelProvider.tenant_id == tenant_id) + .all() + ) provider_name_to_preferred_provider_type_records_dict = { preferred_provider_type.provider_name: preferred_provider_type @@ -419,15 +406,17 @@ def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderM :param tenant_id: workspace id :return: """ - provider_model_settings = db.session.query(ProviderModelSetting) \ - .filter( - ProviderModelSetting.tenant_id == tenant_id - ).all() + provider_model_settings = ( + db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() + ) provider_name_to_provider_model_settings_dict = defaultdict(list) for provider_model_setting in provider_model_settings: - (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] - .append(provider_model_setting)) + ( + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + ) return provider_name_to_provider_model_settings_dict @@ -445,27 +434,30 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) else: - cache_result = cache_result.decode('utf-8') - model_load_balancing_enabled = cache_result == 'True' + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ - .filter( - LoadBalancingModelConfig.tenant_id == tenant_id - ).all() + provider_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() + ) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) for provider_load_balancing_config in provider_load_balancing_configs: - (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] - .append(provider_load_balancing_config)) + ( + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) + ) return provider_name_to_provider_load_balancing_model_configs_dict @staticmethod - def _init_trial_provider_records(tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: + def _init_trial_provider_records( + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] + ) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -489,8 +481,9 @@ def _init_trial_provider_records(tenant_id: str, if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: @@ -504,19 +497,22 @@ def _init_trial_provider_records(tenant_id: str, quota_type=ProviderQuotaType.TRIAL.value, quota_limit=quota.quota_limit, quota_used=0, - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() except IntegrityError: db.session.rollback() - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + ) + .first() + ) if provider_record and not provider_record.is_valid: provider_record.is_valid = True @@ -526,11 +522,13 @@ def _init_trial_provider_records(tenant_id: str, return provider_name_to_provider_records_dict - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: + def _to_custom_configuration( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_records: list[Provider], + provider_model_records: list[ProviderModel], + ) -> CustomConfiguration: """ Convert to custom configuration. @@ -543,7 +541,8 @@ def _to_custom_configuration(self, # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get custom provider record @@ -563,7 +562,7 @@ def _to_custom_configuration(self, provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -572,11 +571,11 @@ def _to_custom_configuration(self, if not cached_provider_credentials: try: # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } + if ( + custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{") + ): + provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: @@ -590,28 +589,23 @@ def _to_custom_configuration(self, if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) + provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) + custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials @@ -621,9 +615,7 @@ def _to_custom_configuration(self, continue provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) # Get cached provider model credentials @@ -645,15 +637,13 @@ def _to_custom_configuration(self, provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials @@ -661,19 +651,15 @@ def _to_custom_configuration(self, CustomModelConfiguration( model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials + credentials=provider_model_credentials, ) ) - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) + return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: + def _to_system_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> SystemConfiguration: """ Convert to system configuration. @@ -685,11 +671,11 @@ def _to_system_configuration(self, # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) + if ( + provider_entity.provider not in hosting_configuration.provider_map + or not hosting_configuration.provider_map.get(provider_entity.provider).enabled + ): + return SystemConfiguration(enabled=False) provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) @@ -699,8 +685,9 @@ def _to_system_configuration(self, if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: @@ -712,7 +699,7 @@ def _to_system_configuration(self, quota_used=0, quota_limit=0, is_valid=False, - restrict_models=provider_quota.restrict_models + restrict_models=provider_quota.restrict_models, ) else: continue @@ -724,16 +711,15 @@ def _to_system_configuration(self, quota_unit=provider_hosting_configuration.quota_unit, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, ) quota_configurations.append(quota_configuration) if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) + return SystemConfiguration(enabled=False) current_quota_type = self._choice_current_using_quota_type(quota_configurations) @@ -745,7 +731,7 @@ def _to_system_configuration(self, provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -760,7 +746,8 @@ def _to_system_configuration(self, # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get decoding rsa key and cipher for decrypting credentials @@ -771,9 +758,7 @@ def _to_system_configuration(self, if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass @@ -781,9 +766,7 @@ def _to_system_configuration(self, current_using_credentials = provider_credentials # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) + provider_credentials_cache.set(credentials=current_using_credentials) else: current_using_credentials = cached_provider_credentials else: @@ -794,7 +777,7 @@ def _to_system_configuration(self, enabled=True, current_quota_type=current_quota_type, quota_configurations=quota_configurations, - credentials=current_using_credentials + credentials=current_using_credentials, ) @staticmethod @@ -809,8 +792,7 @@ def _choice_current_using_quota_type(quota_configurations: list[QuotaConfigurati """ # convert to dict quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations + quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } last_quota_configuration = None @@ -823,7 +805,7 @@ def _choice_current_using_quota_type(quota_configurations: list[QuotaConfigurati if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError('No quota type available') + raise ValueError("No quota type available") @staticmethod def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: @@ -840,10 +822,12 @@ def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema return secret_input_form_variables - def _to_model_settings(self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ - -> list[ModelSettings]: + def _to_model_settings( + self, + provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + ) -> list[ModelSettings]: """ Convert to model settings. :param provider_entity: provider entity @@ -854,7 +838,8 @@ def _to_model_settings(self, provider_entity: ProviderEntity, # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) model_settings = [] @@ -865,24 +850,28 @@ def _to_model_settings(self, provider_entity: ProviderEntity, load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: for load_balancing_model_config in load_balancing_model_configs: - if (load_balancing_model_config.model_name == provider_model_setting.model_name - and load_balancing_model_config.model_type == provider_model_setting.model_type): + if ( + load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type + ): if not load_balancing_model_config.enabled: continue if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={} - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + ) + ) continue provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) # Get cached provider model credentials @@ -897,7 +886,8 @@ def _to_model_settings(self, provider_entity: ProviderEntity, # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( - load_balancing_model_config.tenant_id) + load_balancing_model_config.tenant_id + ) for variable in model_credential_secret_variables: if variable in provider_model_credentials: @@ -905,30 +895,30 @@ def _to_model_settings(self, provider_entity: ProviderEntity, provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials=provider_model_credentials - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials, + ) + ) model_settings.append( ModelSettings( model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, - load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index eaad0e0f4c3a45..3c6ab2e4cfc56b 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -2,37 +2,35 @@ class CleanProcessor: - @classmethod def clean(cls, text: str, process_rule: dict) -> str: # default clean # remove invalid symbol - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) - rules = process_rule['rules'] if process_rule else None - if 'pre_processing_rules' in rules: + rules = process_rule["rules"] if process_rule else None + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text def filter_string(self, text): - return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f272c7..d3bc2f765e9654 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -1,12 +1,11 @@ """Abstract interface for document cleaner implementations.""" + from abc import ABC, abstractmethod class BaseCleaner(ABC): - """Interface for clean chunk content. - """ + """Interface for clean chunk content.""" @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py index 6a0b8c904603da..167a919e69aa31 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_extra_whitespace diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py index 6fc3a408dacbc6..9c682d29db376d 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" import re diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py index 87dc2d49faae87..0cdbb171e1081e 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_non_ascii_chars diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py index 974a28fef16f3f..9f42044a2d5db8 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -1,11 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """Replaces unicode quote characters, such as the \x91 character in a string.""" from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py index dfaf3a27874af2..32ae7217e878a5 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.translate import translate_text diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index ad9ee4f7cfad60..b1d6f93cff4547 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -12,17 +12,27 @@ class DataPostProcessor: - """Interface for data post-processing document. - """ + """Interface for data post-processing document.""" - def __init__(self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - reorder_enabled: bool = False): + def __init__( + self, + tenant_id: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reorder_enabled: bool = False, + ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) - def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def invoke( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -31,21 +41,26 @@ def invoke(self, query: str, documents: list[Document], score_threshold: Optiona return documents - def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + def _get_rerank_runner( + self, + reranking_mode: str, + tenant_id: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + ) -> Optional[RerankModelRunner | WeightRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: return WeightRerankRunner( tenant_id, Weights( vector_setting=VectorSetting( - vector_weight=weights['vector_setting']['vector_weight'], - embedding_provider_name=weights['vector_setting']['embedding_provider_name'], - embedding_model_name=weights['vector_setting']['embedding_model_name'], + vector_weight=weights["vector_setting"]["vector_weight"], + embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], + embedding_model_name=weights["vector_setting"]["embedding_model_name"], ), keyword_setting=KeywordSetting( - keyword_weight=weights['keyword_setting']['keyword_weight'], - ) - ) + keyword_weight=weights["keyword_setting"]["keyword_weight"], + ), + ), ) elif reranking_mode == RerankMode.RERANKING_MODEL.value: if reranking_model: @@ -53,9 +68,9 @@ def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_mode model_manager = ModelManager() rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], + provider=reranking_model["reranking_provider_name"], model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] + model=reranking_model["reranking_model_name"], ) except InvokeAuthorizationError: return None @@ -67,5 +82,3 @@ def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None - - diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 71297588a4e7e5..a9a0885241e4fd 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -2,7 +2,6 @@ class ReorderRunner: - def run(self, documents: list[Document]) -> list[Document]: # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list odd_elements = documents[::2] diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a3714c2fd3a38c..3073100746d360 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,37 +24,42 @@ def __init__(self, dataset: Dataset): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) + keywords_list = kwargs.get("keywords_list", None) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords(text.page_content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) @@ -63,97 +68,91 @@ def text_exists(self, id: str) -> bool: return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - k = kwargs.get('top_k', 4) + k = kwargs.get("top_k", 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) documents = [] for chunk_index in sorted_chunk_indices: - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.index_node_id == chunk_index - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) + .first() + ) if segment: - - documents.append(Document( - page_content=segment.content, - metadata={ - "doc_id": chunk_index, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - )) + documents.append( + Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + ) return documents def delete(self) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: db.session.delete(dataset_keyword_table) db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + if dataset_keyword_table.data_source_type != "database": + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": keyword_table - } + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type - if keyword_data_source_type == 'database': + if keyword_data_source_type == "database": dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) db.session.commit() else: - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) + storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict['__data__']['table'] + return keyword_table_dict["__data__"]["table"] else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, - keyword_table='', + keyword_table="", data_source_type=keyword_data_source_type, ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) + if keyword_data_source_type == "database": + dataset_keyword_table.keyword_table = json.dumps( + { + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, + }, + cls=SetEncoder, + ) db.session.add(dataset_keyword_table) db.session.commit() @@ -174,9 +173,7 @@ def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> keywords_to_delete = set() for keyword, node_idxs in keyword_table.items(): if node_idxs_to_delete.intersection(node_idxs): - keyword_table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) + keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete) if not keyword_table[keyword]: keywords_to_delete.add(keyword) @@ -202,13 +199,14 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): reverse=True, ) - return sorted_chunk_indices[: k] + return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id == node_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .first() + ) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) @@ -224,14 +222,14 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] - if pre_segment_data['keywords']: - segment.keywords = pre_segment_data['keywords'] - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, - pre_segment_data['keywords']) + segment = pre_segment_data["segment"] + if pre_segment_data["keywords"]: + segment.keywords = pre_segment_data["keywords"] + keyword_table = self._add_text_to_keyword_table( + keyword_table, segment.index_node_id, pre_segment_data["keywords"] + ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ad669ef5150bef..4b1ade8e3fa095 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -8,7 +8,6 @@ class JiebaKeywordTableHandler: - def __init__(self): default_tfidf.stop_words = STOPWORDS @@ -30,4 +29,4 @@ def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index c616a15cf0c20f..9abe78d6ef7e8d 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -1,90 +1,1380 @@ STOPWORDS = { - "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", - "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her", - "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", - "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", - "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now", - "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", - "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", - "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", - "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", - "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", - "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", - "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", - "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", - "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", - "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", - "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", - "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", - "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是", - "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", - "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", - "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", - "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", - "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", - "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", - "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", - "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", - "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", - "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", - "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", - "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", - "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", - "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", - "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", - "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", - "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", - "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照", - "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", - "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", - "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", - "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", - "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", - "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", - "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", - "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", - "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", - " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", - "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", - "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", - "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", - "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", - "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", - "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", - "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", - "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止", - "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", - "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", - "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", - "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", - "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", - "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", - "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", - "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", - "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", - "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", - "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", - "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", - "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", - "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", - "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", - "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", - "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", - "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", - "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", - "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", - "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", - "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外", - "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", - "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", - "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", - "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", - "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", - "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", - "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", - "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", - "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", - "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", - "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", - "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" + "during", + "when", + "but", + "then", + "further", + "isn", + "mustn't", + "until", + "own", + "i", + "couldn", + "y", + "only", + "you've", + "ours", + "who", + "where", + "ourselves", + "has", + "to", + "was", + "didn't", + "themselves", + "if", + "against", + "through", + "her", + "an", + "your", + "can", + "those", + "didn", + "about", + "aren't", + "shan't", + "be", + "not", + "these", + "again", + "so", + "t", + "theirs", + "weren", + "won't", + "won", + "itself", + "just", + "same", + "while", + "why", + "doesn", + "aren", + "him", + "haven", + "for", + "you'll", + "that", + "we", + "am", + "d", + "by", + "having", + "wasn't", + "than", + "weren't", + "out", + "from", + "now", + "their", + "too", + "hadn", + "o", + "needn", + "most", + "it", + "under", + "needn't", + "any", + "some", + "few", + "ll", + "hers", + "which", + "m", + "you're", + "off", + "other", + "had", + "she", + "you'd", + "do", + "you", + "does", + "s", + "will", + "each", + "wouldn't", + "hasn't", + "such", + "more", + "whom", + "she's", + "my", + "yours", + "yourself", + "of", + "on", + "very", + "hadn't", + "with", + "yourselves", + "been", + "ma", + "them", + "mightn't", + "shan", + "mustn", + "they", + "what", + "both", + "that'll", + "how", + "is", + "he", + "because", + "down", + "haven't", + "are", + "no", + "it's", + "our", + "being", + "the", + "or", + "above", + "myself", + "once", + "don't", + "doesn't", + "as", + "nor", + "here", + "herself", + "hasn", + "mightn", + "have", + "its", + "all", + "were", + "ain", + "this", + "at", + "after", + "over", + "shouldn't", + "into", + "before", + "don", + "wouldn", + "re", + "couldn't", + "wasn", + "in", + "should", + "there", + "himself", + "isn't", + "should've", + "doing", + "ve", + "shouldn", + "a", + "did", + "and", + "his", + "between", + "me", + "up", + "below", + "人民", + "末##末", + "啊", + "阿", + "哎", + "哎呀", + "哎哟", + "唉", + "俺", + "俺们", + "按", + "按照", + "吧", + "吧哒", + "把", + "罢了", + "被", + "本", + "本着", + "比", + "比方", + "比如", + "鄙人", + "彼", + "彼此", + "边", + "别", + "别的", + "别说", + "并", + "并且", + "不比", + "不成", + "不单", + "不但", + "不独", + "不管", + "不光", + "不过", + "不仅", + "不拘", + "不论", + "不怕", + "不然", + "不如", + "不特", + "不惟", + "不问", + "不只", + "朝", + "朝着", + "趁", + "趁着", + "乘", + "冲", + "除", + "除此之外", + "除非", + "除了", + "此", + "此间", + "此外", + "从", + "从而", + "打", + "待", + "但", + "但是", + "当", + "当着", + "到", + "得", + "的", + "的话", + "等", + "等等", + "地", + "第", + "叮咚", + "对", + "对于", + "多", + "多少", + "而", + "而况", + "而且", + "而是", + "而外", + "而言", + "而已", + "尔后", + "反过来", + "反过来说", + "反之", + "非但", + "非徒", + "否则", + "嘎", + "嘎登", + "该", + "赶", + "个", + "各", + "各个", + "各位", + "各种", + "各自", + "给", + "根据", + "跟", + "故", + "故此", + "固然", + "关于", + "管", + "归", + "果然", + "果真", + "过", + "哈", + "哈哈", + "呵", + "和", + "何", + "何处", + "何况", + "何时", + "嘿", + "哼", + "哼唷", + "呼哧", + "乎", + "哗", + "还是", + "还有", + "换句话说", + "换言之", + "或", + "或是", + "或者", + "极了", + "及", + "及其", + "及至", + "即", + "即便", + "即或", + "即令", + "即若", + "即使", + "几", + "几时", + "己", + "既", + "既然", + "既是", + "继而", + "加之", + "假如", + "假若", + "假使", + "鉴于", + "将", + "较", + "较之", + "叫", + "接着", + "结果", + "借", + "紧接着", + "进而", + "尽", + "尽管", + "经", + "经过", + "就", + "就是", + "就是说", + "据", + "具体地说", + "具体说来", + "开始", + "开外", + "靠", + "咳", + "可", + "可见", + "可是", + "可以", + "况且", + "啦", + "来", + "来着", + "离", + "例如", + "哩", + "连", + "连同", + "两者", + "了", + "临", + "另", + "另外", + "另一方面", + "论", + "嘛", + "吗", + "慢说", + "漫说", + "冒", + "么", + "每", + "每当", + "们", + "莫若", + "某", + "某个", + "某些", + "拿", + "哪", + "哪边", + "哪儿", + "哪个", + "哪里", + "哪年", + "哪怕", + "哪天", + "哪些", + "哪样", + "那", + "那边", + "那儿", + "那个", + "那会儿", + "那里", + "那么", + "那么些", + "那么样", + "那时", + "那些", + "那样", + "乃", + "乃至", + "呢", + "能", + "你", + "你们", + "您", + "宁", + "宁可", + "宁肯", + "宁愿", + "哦", + "呕", + "啪达", + "旁人", + "呸", + "凭", + "凭借", + "其", + "其次", + "其二", + "其他", + "其它", + "其一", + "其余", + "其中", + "起", + "起见", + "岂但", + "恰恰相反", + "前后", + "前者", + "且", + "然而", + "然后", + "然则", + "让", + "人家", + "任", + "任何", + "任凭", + "如", + "如此", + "如果", + "如何", + "如其", + "如若", + "如上所述", + "若", + "若非", + "若是", + "啥", + "上下", + "尚且", + "设若", + "设使", + "甚而", + "甚么", + "甚至", + "省得", + "时候", + "什么", + "什么样", + "使得", + "是", + "是的", + "首先", + "谁", + "谁知", + "顺", + "顺着", + "似的", + "虽", + "虽然", + "虽说", + "虽则", + "随", + "随着", + "所", + "所以", + "他", + "他们", + "他人", + "它", + "它们", + "她", + "她们", + "倘", + "倘或", + "倘然", + "倘若", + "倘使", + "腾", + "替", + "通过", + "同", + "同时", + "哇", + "万一", + "往", + "望", + "为", + "为何", + "为了", + "为什么", + "为着", + "喂", + "嗡嗡", + "我", + "我们", + "呜", + "呜呼", + "乌乎", + "无论", + "无宁", + "毋宁", + "嘻", + "吓", + "相对而言", + "像", + "向", + "向着", + "嘘", + "呀", + "焉", + "沿", + "沿着", + "要", + "要不", + "要不然", + "要不是", + "要么", + "要是", + "也", + "也罢", + "也好", + "一", + "一般", + "一旦", + "一方面", + "一来", + "一切", + "一样", + "一则", + "依", + "依照", + "矣", + "以", + "以便", + "以及", + "以免", + "以至", + "以至于", + "以致", + "抑或", + "因", + "因此", + "因而", + "因为", + "哟", + "用", + "由", + "由此可见", + "由于", + "有", + "有的", + "有关", + "有些", + "又", + "于", + "于是", + "于是乎", + "与", + "与此同时", + "与否", + "与其", + "越是", + "云云", + "哉", + "再说", + "再者", + "在", + "在下", + "咱", + "咱们", + "则", + "怎", + "怎么", + "怎么办", + "怎么样", + "怎样", + "咋", + "照", + "照着", + "者", + "这", + "这边", + "这儿", + "这个", + "这会儿", + "这就是说", + "这里", + "这么", + "这么点儿", + "这么些", + "这么样", + "这时", + "这些", + "这样", + "正如", + "吱", + "之", + "之类", + "之所以", + "之一", + "只是", + "只限", + "只要", + "只有", + "至", + "至于", + "诸位", + "着", + "着呢", + "自", + "自从", + "自个儿", + "自各儿", + "自己", + "自家", + "自身", + "综上所述", + "总的来看", + "总的来说", + "总的说来", + "总而言之", + "总之", + "纵", + "纵令", + "纵然", + "纵使", + "遵照", + "作为", + "兮", + "呃", + "呗", + "咚", + "咦", + "喏", + "啐", + "喔唷", + "嗬", + "嗯", + "嗳", + "~", + "!", + ".", + ":", + '"', + "'", + "(", + ")", + "*", + "A", + "白", + "社会主义", + "--", + "..", + ">>", + " [", + " ]", + "", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "+", + "=", + "&", + "^", + "%", + "#", + "@", + "`", + ";", + "$", + "(", + ")", + "——", + "—", + "¥", + "·", + "...", + "‘", + "’", + "〉", + "〈", + "…", + " ", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "零", + ">", + "<", + "@", + "#", + "$", + "%", + "︿", + "&", + "*", + "+", + "~", + "|", + "[", + "]", + "{", + "}", + "啊哈", + "啊呀", + "啊哟", + "挨次", + "挨个", + "挨家挨户", + "挨门挨户", + "挨门逐户", + "挨着", + "按理", + "按期", + "按时", + "按说", + "暗地里", + "暗中", + "暗自", + "昂然", + "八成", + "白白", + "半", + "梆", + "保管", + "保险", + "饱", + "背地里", + "背靠背", + "倍感", + "倍加", + "本人", + "本身", + "甭", + "比起", + "比如说", + "比照", + "毕竟", + "必", + "必定", + "必将", + "必须", + "便", + "别人", + "并非", + "并肩", + "并没", + "并没有", + "并排", + "并无", + "勃然", + "不", + "不必", + "不常", + "不大", + "不但...而且", + "不得", + "不得不", + "不得了", + "不得已", + "不迭", + "不定", + "不对", + "不妨", + "不管怎样", + "不会", + "不仅...而且", + "不仅仅", + "不仅仅是", + "不经意", + "不可开交", + "不可抗拒", + "不力", + "不了", + "不料", + "不满", + "不免", + "不能不", + "不起", + "不巧", + "不然的话", + "不日", + "不少", + "不胜", + "不时", + "不是", + "不同", + "不能", + "不要", + "不外", + "不外乎", + "不下", + "不限", + "不消", + "不已", + "不亦乐乎", + "不由得", + "不再", + "不择手段", + "不怎么", + "不曾", + "不知不觉", + "不止", + "不止一次", + "不至于", + "才", + "才能", + "策略地", + "差不多", + "差一点", + "常", + "常常", + "常言道", + "常言说", + "常言说得好", + "长此下去", + "长话短说", + "长期以来", + "长线", + "敞开儿", + "彻夜", + "陈年", + "趁便", + "趁机", + "趁热", + "趁势", + "趁早", + "成年", + "成年累月", + "成心", + "乘机", + "乘胜", + "乘势", + "乘隙", + "乘虚", + "诚然", + "迟早", + "充分", + "充其极", + "充其量", + "抽冷子", + "臭", + "初", + "出", + "出来", + "出去", + "除此", + "除此而外", + "除此以外", + "除开", + "除去", + "除却", + "除外", + "处处", + "川流不息", + "传", + "传说", + "传闻", + "串行", + "纯", + "纯粹", + "此后", + "此中", + "次第", + "匆匆", + "从不", + "从此", + "从此以后", + "从古到今", + "从古至今", + "从今以后", + "从宽", + "从来", + "从轻", + "从速", + "从头", + "从未", + "从无到有", + "从小", + "从新", + "从严", + "从优", + "从早到晚", + "从中", + "从重", + "凑巧", + "粗", + "存心", + "达旦", + "打从", + "打开天窗说亮话", + "大", + "大不了", + "大大", + "大抵", + "大都", + "大多", + "大凡", + "大概", + "大家", + "大举", + "大略", + "大面儿上", + "大事", + "大体", + "大体上", + "大约", + "大张旗鼓", + "大致", + "呆呆地", + "带", + "殆", + "待到", + "单", + "单纯", + "单单", + "但愿", + "弹指之间", + "当场", + "当儿", + "当即", + "当口儿", + "当然", + "当庭", + "当头", + "当下", + "当真", + "当中", + "倒不如", + "倒不如说", + "倒是", + "到处", + "到底", + "到了儿", + "到目前为止", + "到头", + "到头来", + "得起", + "得天独厚", + "的确", + "等到", + "叮当", + "顶多", + "定", + "动不动", + "动辄", + "陡然", + "都", + "独", + "独自", + "断然", + "顿时", + "多次", + "多多", + "多多少少", + "多多益善", + "多亏", + "多年来", + "多年前", + "而后", + "而论", + "而又", + "尔等", + "二话不说", + "二话没说", + "反倒", + "反倒是", + "反而", + "反手", + "反之亦然", + "反之则", + "方", + "方才", + "方能", + "放量", + "非常", + "非得", + "分期", + "分期分批", + "分头", + "奋勇", + "愤然", + "风雨无阻", + "逢", + "弗", + "甫", + "嘎嘎", + "该当", + "概", + "赶快", + "赶早不赶晚", + "敢", + "敢情", + "敢于", + "刚", + "刚才", + "刚好", + "刚巧", + "高低", + "格外", + "隔日", + "隔夜", + "个人", + "各式", + "更", + "更加", + "更进一步", + "更为", + "公然", + "共", + "共总", + "够瞧的", + "姑且", + "古来", + "故而", + "故意", + "固", + "怪", + "怪不得", + "惯常", + "光", + "光是", + "归根到底", + "归根结底", + "过于", + "毫不", + "毫无", + "毫无保留地", + "毫无例外", + "好在", + "何必", + "何尝", + "何妨", + "何苦", + "何乐而不为", + "何须", + "何止", + "很", + "很多", + "很少", + "轰然", + "后来", + "呼啦", + "忽地", + "忽然", + "互", + "互相", + "哗啦", + "话说", + "还", + "恍然", + "会", + "豁然", + "活", + "伙同", + "或多或少", + "或许", + "基本", + "基本上", + "基于", + "极", + "极大", + "极度", + "极端", + "极力", + "极其", + "极为", + "急匆匆", + "即将", + "即刻", + "即是说", + "几度", + "几番", + "几乎", + "几经", + "既...又", + "继之", + "加上", + "加以", + "间或", + "简而言之", + "简言之", + "简直", + "见", + "将才", + "将近", + "将要", + "交口", + "较比", + "较为", + "接连不断", + "接下来", + "皆可", + "截然", + "截至", + "藉以", + "借此", + "借以", + "届时", + "仅", + "仅仅", + "谨", + "进来", + "进去", + "近", + "近几年来", + "近来", + "近年来", + "尽管如此", + "尽可能", + "尽快", + "尽量", + "尽然", + "尽如人意", + "尽心竭力", + "尽心尽力", + "尽早", + "精光", + "经常", + "竟", + "竟然", + "究竟", + "就此", + "就地", + "就算", + "居然", + "局外", + "举凡", + "据称", + "据此", + "据实", + "据说", + "据我所知", + "据悉", + "具体来说", + "决不", + "决非", + "绝", + "绝不", + "绝顶", + "绝对", + "绝非", + "均", + "喀", + "看", + "看来", + "看起来", + "看上去", + "看样子", + "可好", + "可能", + "恐怕", + "快", + "快要", + "来不及", + "来得及", + "来讲", + "来看", + "拦腰", + "牢牢", + "老", + "老大", + "老老实实", + "老是", + "累次", + "累年", + "理当", + "理该", + "理应", + "历", + "立", + "立地", + "立刻", + "立马", + "立时", + "联袂", + "连连", + "连日", + "连日来", + "连声", + "连袂", + "临到", + "另方面", + "另行", + "另一个", + "路经", + "屡", + "屡次", + "屡次三番", + "屡屡", + "缕缕", + "率尔", + "率然", + "略", + "略加", + "略微", + "略为", + "论说", + "马上", + "蛮", + "满", + "没", + "没有", + "每逢", + "每每", + "每时每刻", + "猛然", + "猛然间", + "莫", + "莫不", + "莫非", + "莫如", + "默默地", + "默然", + "呐", + "那末", + "奈", + "难道", + "难得", + "难怪", + "难说", + "内", + "年复一年", + "凝神", + "偶而", + "偶尔", + "怕", + "砰", + "碰巧", + "譬如", + "偏偏", + "乒", + "平素", + "颇", + "迫于", + "扑通", + "其后", + "其实", + "奇", + "齐", + "起初", + "起来", + "起首", + "起头", + "起先", + "岂", + "岂非", + "岂止", + "迄", + "恰逢", + "恰好", + "恰恰", + "恰巧", + "恰如", + "恰似", + "千", + "千万", + "千万千万", + "切", + "切不可", + "切莫", + "切切", + "切勿", + "窃", + "亲口", + "亲身", + "亲手", + "亲眼", + "亲自", + "顷", + "顷刻", + "顷刻间", + "顷刻之间", + "请勿", + "穷年累月", + "取道", + "去", + "权时", + "全都", + "全力", + "全年", + "全然", + "全身心", + "然", + "人人", + "仍", + "仍旧", + "仍然", + "日复一日", + "日见", + "日渐", + "日益", + "日臻", + "如常", + "如此等等", + "如次", + "如今", + "如期", + "如前所述", + "如上", + "如下", + "汝", + "三番两次", + "三番五次", + "三天两头", + "瑟瑟", + "沙沙", + "上", + "上来", + "上去", + "一个", + "月", + "日", + "\n", } diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b77c6562b25ce3..27e4f383ad6a96 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -8,7 +8,6 @@ class BaseKeyword(ABC): - def __init__(self, dataset: Dataset): self.dataset = dataset @@ -31,15 +30,12 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete(self) -> None: raise NotImplementedError - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -47,4 +43,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 6ac610f82b45ba..3c99f33be61e36 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -20,9 +20,7 @@ def _init_keyword(self) -> BaseKeyword: raise ValueError("Keyword store must be specified.") if keyword_type == "jieba": - return Jieba( - dataset=self._dataset - ) + return Jieba(dataset=self._dataset) else: raise ValueError(f"Keyword store {keyword_type} is not supported.") @@ -41,10 +39,7 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete(self) -> None: self._keyword_processor.delete() - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: return self._keyword_processor.search(query, **kwargs) def __getattr__(self, name): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 0dac9bfae64107..afac1bf30086e9 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -12,73 +12,83 @@ from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class RetrievalService: - @classmethod - def retrieve(cls, retrieval_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword - if retrieval_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic if RetrievalMethod.is_support_semantic_search(retrieval_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrieval_method': retrieval_method, - 'exceptions': exceptions, - }) + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text if RetrievalMethod.is_support_fulltext_search(retrieval_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrieval_method': retrieval_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -86,110 +96,117 @@ def retrieve(cls, retrieval_method: str, dataset_id: str, query: str, thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list, exceptions: list): + def keyword_search( + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword(dataset=dataset) - documents = keyword.search( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def embedding_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) documents = vector.search_by_vector( cls.escape_query_for_search(query), - search_type='similarity_score_threshold', + search_type="similarity_score_threshold", top_k=top_k, score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + filter={"group_id": [dataset.id]}, ) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def full_text_index_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() vector_processor = Vector( dataset=dataset, ) - documents = vector_processor.search_by_full_text( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: @@ -197,4 +214,4 @@ def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, @staticmethod def escape_query_for_search(query: str) -> str: - return query.replace('"', '\\"') \ No newline at end of file + return query.replace('"', '\\"') diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b78e2a59b1eb6f..a9c0eefb786f34 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel): namespace_password: str = (None,) metrics: str = ("cosine",) read_timeout: int = 60000 + def to_analyticdb_client_params(self): return { "access_key_id": self.access_key_id, @@ -37,6 +38,7 @@ def to_analyticdb_client_params(self): "read_timeout": self.read_timeout, } + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -57,9 +59,7 @@ def __init__(self, collection_name: str, config: AnalyticdbConfig): except: raise ImportError(_import_err_msg) self.config = config - self._client_config = open_api_models.Config( - user_agent="dify", **config.to_analyticdb_client_params() - ) + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) self._client = Client(self._client_config) self._initialize() AnalyticdbVector._init = True @@ -77,6 +77,7 @@ def _initialize(self) -> None: def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -88,6 +89,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + try: request = gpdb_20160503_models.DescribeNamespaceRequest( dbinstance_id=self.config.instance_id, @@ -109,13 +111,12 @@ def _create_namespace_if_not_exists(self) -> None: ) self._client.create_namespace(request) else: - raise ValueError( - f"failed to create namespace {self.config.namespace}: {e}" - ) + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -149,9 +150,7 @@ def _create_collection_if_not_exists(self, embedding_dimension: int): ) self._client.create_collection(request) else: - raise ValueError( - f"failed to create collection {self._collection_name}: {e}" - ) + raise ValueError(f"failed to create collection {self._collection_name}: {e}") redis_client.set(collection_exist_cache_key, 1, ex=3600) def get_type(self) -> str: @@ -162,10 +161,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self._create_collection_if_not_exists(dimension) self.add_texts(texts, embeddings) - def add_texts( - self, documents: list[Document], embeddings: list[list[float]], **kwargs - ): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): metadata = { @@ -191,6 +189,7 @@ def add_texts( def text_exists(self, id: str) -> bool: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -202,13 +201,14 @@ def text_exists(self, id: str) -> bool: vector=None, content=None, top_k=1, - filter=f"ref_doc_id='{id}'" + filter=f"ref_doc_id='{id}'", ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + ids_str = ",".join(f"'{id}'" for id in ids) ids_str = f"({ids_str})" request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -224,6 +224,7 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -235,15 +236,10 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: ) self._client.delete_collection_data(request) - def search_by_vector( - self, query_vector: list[float], **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -270,11 +266,8 @@ def search_by_vector( def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -304,6 +297,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def delete(self) -> None: try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( collection=self._collection_name, dbinstance_id=self.config.instance_id, @@ -315,19 +309,16 @@ def delete(self) -> None: except Exception as e: raise e + class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"][ - "class_prefix" - ] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) - ) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) # handle optional params if dify_config.ANALYTICDB_KEY_ID is None: diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3629887b448aeb..cb38cf94a96e8f 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -27,21 +27,20 @@ def to_chroma_params(self): settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, - chroma_client_auth_credentials=self.auth_credentials + chroma_client_auth_credentials=self.auth_credentials, ) return { - 'host': self.host, - 'port': self.port, - 'ssl': False, - 'tenant': self.tenant, - 'database': self.database, - 'settings': settings, + "host": self.host, + "port": self.port, + "ssl": False, + "tenant": self.tenant, + "database": self.database, + "settings": settings, } class ChromaVector(BaseVector): - def __init__(self, collection_name: str, config: ChromaConfig): super().__init__(collection_name) self._client_config = config @@ -58,9 +57,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) @@ -76,7 +75,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {'$eq': value}}) + collection.delete(where={key: {"$eq": value}}) def delete(self): self._client.delete_collection(self._collection_name) @@ -93,26 +92,26 @@ def text_exists(self, id: str) -> bool: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 - ids: list[str] = results['ids'][0] - documents: list[str] = results['documents'][0] - metadatas: dict[str, Any] = results['metadatas'][0] - distances: list[float] = results['distances'][0] + ids: list[str] = results["ids"][0] + documents: list[str] = results["documents"][0] + metadatas: dict[str, Any] = results["metadatas"][0] + distances: list[float] = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] metadata = metadatas[index] if distance >= score_threshold: - metadata['score'] = distance + metadata["score"] = distance doc = Document( page_content=documents[index], metadata=metadata, ) docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -123,15 +122,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": VectorType.CHROMA, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 233539756fa84f..76c808f76e5aa9 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel): username: str password: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PORT is required") - if not values['username']: + if not values["username"]: raise ValueError("config USERNAME is required") - if not values['password']: + if not values["password"]: raise ValueError("config PASSWORD is required") return values @@ -50,10 +50,10 @@ def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: lis def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ['http', 'https']: - hosts = f'{config.host}:{config.port}' + if parsed_url.scheme in ["http", "https"]: + hosts = f"{config.host}:{config.port}" else: - hosts = f'http://{config.host}:{config.port}' + hosts = f"http://{config.host}:{config.port}" client = Elasticsearch( hosts=hosts, basic_auth=(config.username, config.password), @@ -68,25 +68,27 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _get_version(self) -> str: info = self._client.info() - return info['version']['number'] + return info["version"]["number"] def _check_version(self): - if self._version < '8.0.0': + if self._version < "8.0.0": raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return 'elasticsearch' + return "elasticsearch" def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): - self._client.index(index=self._collection_name, - id=uuids[i], - document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] if embeddings[i] else None, - Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} - }) + self._client.index( + index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] if embeddings[i] else None, + Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}, + }, + ) self._client.indices.refresh(index=self._collection_name) return uuids @@ -98,15 +100,9 @@ def delete_by_ids(self, ids: list[str]) -> None: self._client.delete(index=self._collection_name, id=id) def delete_by_metadata_field(self, key: str, value: str) -> None: - query_str = { - 'query': { - 'match': { - f'metadata.{key}': f'{value}' - } - } - } + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit['_id'] for hit in results['hits']['hits']] + ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) @@ -115,44 +111,44 @@ def delete(self) -> None: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 10) - knn = { - "field": Field.VECTOR.value, - "query_vector": query_vector, - "k": top_k - } + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] - for hit in results['hits']['hits']: + for hit in results["hits"]["hits"]: docs_and_scores.append( - (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = { - "match": { - Field.CONTENT_KEY.value: query - } - } + query_str = {"match": {Field.CONTENT_KEY.value: query}} results = self._client.search(index=self._collection_name, query=query_str) docs = [] - for hit in results['hits']['hits']: - docs.append(Document( - page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value], - )) + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) return docs @@ -162,11 +158,11 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name}' + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name} already exists.") return @@ -179,14 +175,14 @@ def create_collection( Field.VECTOR.value: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, - "similarity": "cosine" + "similarity": "cosine", }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } } self._client.indices.create(index=self._collection_name, mappings=mappings) @@ -197,22 +193,21 @@ def create_collection( class ElasticSearchVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get('ELASTICSEARCH_HOST'), - port=config.get('ELASTICSEARCH_PORT'), - username=config.get('ELASTICSEARCH_USERNAME'), - password=config.get('ELASTICSEARCH_PASSWORD'), + host=config.get("ELASTICSEARCH_HOST"), + port=config.get("ELASTICSEARCH_PORT"), + username=config.get("ELASTICSEARCH_USERNAME"), + password=config.get("ELASTICSEARCH_PASSWORD"), ), - attributes=[] + attributes=[], ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index c1c73d1c0dadf5..1d08046641f3e4 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -27,44 +27,39 @@ class MilvusConfig(BaseModel): batch_size: int = 100 database: str = "default" - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values.get('uri'): + if not values.get("uri"): raise ValueError("config MILVUS_URI is required") - if not values.get('user'): + if not values.get("user"): raise ValueError("config MILVUS_USER is required") - if not values.get('password'): + if not values.get("password"): raise ValueError("config MILVUS_PASSWORD is required") return values def to_milvus_params(self): return { - 'uri': self.uri, - 'token': self.token, - 'user': self.user, - 'password': self.password, - 'db_name': self.database, + "uri": self.uri, + "token": self.token, + "user": self.user, + "password": self.password, + "db_name": self.database, } class MilvusVector(BaseVector): - def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = 'Session' + self._consistency_level = "Session" self._fields = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } + index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} metadatas = [d.metadata for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -75,7 +70,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata + Field.METADATA_KEY.value: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -84,22 +79,20 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i:i + 1000] + batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks def get_ids_by_metadata_field(self, key: str, value: str): - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["{key}"] == "{value}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] + ) if result: return [item["id"] for item in result] else: @@ -107,17 +100,15 @@ def get_ids_by_metadata_field(self, key: str, value: str): def delete_by_metadata_field(self, key: str, value: str): if self._client.has_collection(self._collection_name): - ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: if self._client.has_collection(self._collection_name): - - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {ids}', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] + ) if result: ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, pks=ids) @@ -130,29 +121,28 @@ def text_exists(self, id: str) -> bool: if not self._client.has_collection(self._collection_name): return False - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] == "{id}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"] + ) return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. - results = self._client.search(collection_name=self._collection_name, - data=[query_vector], - limit=kwargs.get('top_k', 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], - ) + results = self._client.search( + collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) # Organize results. docs = [] for result in results[0]: - metadata = result['entity'].get(Field.METADATA_KEY.value) - metadata['score'] = result['distance'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if result['distance'] > score_threshold: - doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), - metadata=metadata) + metadata = result["entity"].get(Field.METADATA_KEY.value) + metadata["score"] = result["distance"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -161,11 +151,11 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists @@ -180,19 +170,11 @@ def create_collection( fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) + fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema(fields) @@ -208,9 +190,12 @@ def create_collection( # Create the collection collection_name = self._collection_name - self._client.create_collection(collection_name=collection_name, - schema=schema, index_params=index_params_obj, - consistency_level=self._consistency_level) + self._client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params_obj, + consistency_level=self._consistency_level, + ) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: @@ -221,13 +206,12 @@ def _init_client(self, config) -> MilvusClient: class MilvusVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) return MilvusVector( collection_name=collection_name, @@ -237,5 +221,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings user=dify_config.MILVUS_USER, password=dify_config.MILVUS_PASSWORD, database=dify_config.MILVUS_DATABASE, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 05e75effefb23a..90464ac42a86d9 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -31,7 +31,6 @@ class SortOrder(Enum): class MyScaleVector(BaseVector): - def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"): super().__init__(collection_name) self._config = config @@ -80,7 +79,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** doc_id, self.escape_str(doc.page_content), embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {} + json.dumps(doc.metadata) if doc.metadata else {}, ) values.append(str(row)) ids.append(doc_id) @@ -101,7 +100,8 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) -> None: self._client.command( - f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}") + f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" + ) def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( @@ -122,9 +122,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get('score_threshold') or 0.0 - where_str = f"WHERE dist < {1 - score_threshold}" if \ - self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" + score_threshold = kwargs.get("score_threshold") or 0.0 + where_str = ( + f"WHERE dist < {1 - score_threshold}" + if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 + else "" + ) sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} @@ -133,7 +136,7 @@ def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: return [ Document( page_content=r["text"], - vector=r['vector'], + vector=r["vector"], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() @@ -149,13 +152,12 @@ def delete(self) -> None: class MyScaleVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) return MyScaleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index c95d202173b84d..ecd7e0271cefe5 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel): password: Optional[str] = None secure: bool = False - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values.get('host'): + if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") - if not values.get('port'): + if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") return values @@ -44,19 +44,18 @@ def create_ssl_context(self) -> ssl.SSLContext: def to_opensearch_params(self) -> dict[str, Any]: params = { - 'hosts': [{'host': self.host, 'port': self.port}], - 'use_ssl': self.secure, - 'verify_certs': self.secure, + "hosts": [{"host": self.host, "port": self.port}], + "use_ssl": self.secure, + "verify_certs": self.secure, } if self.user and self.password: - params['http_auth'] = (self.user, self.password) + params["http_auth"] = (self.user, self.password) if self.secure: - params['ssl_context'] = self.create_ssl_context() + params["ssl_context"] = self.create_ssl_context() return params class OpenSearchVector(BaseVector): - def __init__(self, collection_name: str, config: OpenSearchConfig): super().__init__(collection_name) self._client_config = config @@ -81,7 +80,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, - } + }, } actions.append(action) @@ -90,8 +89,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] else: return None @@ -110,7 +109,7 @@ def delete_by_ids(self, ids: list[str]) -> None: actual_ids = [] for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + es_ids = self.get_ids_by_metadata_field("doc_id", doc_id) if es_ids: actual_ids.extend(es_ids) else: @@ -122,9 +121,9 @@ def delete_by_ids(self, ids: list[str]) -> None: helpers.bulk(self._client, actions) except BulkIndexError as e: for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') + delete_error = error.get("delete", {}) + status = delete_error.get("status") + doc_id = delete_error.get("_id") if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") @@ -151,15 +150,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc raise ValueError("All elements in query_vector should be floats") query = { - "size": kwargs.get('top_k', 4), - "query": { - "knn": { - Field.VECTOR.value: { - Field.VECTOR.value: query_vector, - "k": kwargs.get('top_k', 4) - } - } - } + "size": kwargs.get("top_k", 4), + "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } try: @@ -169,17 +161,17 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc raise docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) # Make sure metadata is a dictionary if metadata is None: metadata = {} - metadata['score'] = hit['_score'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if hit['_score'] > score_threshold: - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + metadata["score"] = hit["_score"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if hit["_score"] > score_threshold: + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -190,32 +182,28 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: response = self._client.search(index=self._collection_name.lower(), body=full_text_query) docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value) - vector = hit['_source'].get(Field.VECTOR.value) - page_content = hit['_source'].get(Field.CONTENT_KEY.value) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value) + vector = hit["_source"].get(Field.VECTOR.value) + page_content = hit["_source"].get(Field.CONTENT_KEY.value) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name.lower()} already exists.") return if not self._client.indices.exists(index=self._collection_name.lower()): index_body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { Field.CONTENT_KEY.value: {"type": "text"}, @@ -226,20 +214,17 @@ def create_collection( "name": "hnsw", "space_type": "l2", "engine": "faiss", - "parameters": { - "ef_construction": 64, - "m": 8 - } - } + "parameters": {"ef_construction": 64, "m": 8}, + }, }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } - } + }, } self._client.indices.create(index=self._collection_name.lower(), body=index_body) @@ -248,17 +233,14 @@ def create_collection( class OpenSearchVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST, @@ -268,7 +250,4 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings secure=dify_config.OPENSEARCH_SECURE, ) - return OpenSearchVector( - collection_name=collection_name, - config=open_search_config - ) + return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index aa2c6171c33367..eb2e3e0a8ca4ee 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") @@ -103,9 +103,16 @@ def output_type_handler(self, cursor, metadata): arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - def _create_connection_pool(self, config: OracleVectorConfig): - return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1) + def _create_connection_pool(self, config: OracleVectorConfig): + return oracledb.create_pool( + user=config.user, + password=config.password, + dsn="{}:{}/{}".format(config.host, config.port, config.database), + min=1, + max=50, + increment=1, + ) @contextmanager def _get_cursor(self): @@ -136,13 +143,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** doc_id, doc.page_content, json.dumps(doc.metadata), - #array.array("f", embeddings[i]), + # array.array("f", embeddings[i]), numpy.array(embeddings[i]), ) ) - #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") + # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: - cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values) + cur.executemany( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + ) return pks def text_exists(self, id: str) -> bool: @@ -157,7 +166,8 @@ def get_by_ids(self, ids: list[str]) -> list[Document]: for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) return docs - #def get_ids_by_metadata_field(self, key: str, value: str): + + # def get_ids_by_metadata_field(self, key: str, value: str): # with self._get_cursor() as cur: # cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" ) # idss = [] @@ -184,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc top_k = kwargs.get("top_k", 5) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)] + f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only", + [numpy.array(query_vector)], ) docs = [] score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 @@ -202,7 +213,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if len(query) > 0: # Check which language the query is in - zh_pattern = re.compile('[\u4e00-\u9fa5]+') + zh_pattern = re.compile("[\u4e00-\u9fa5]+") match = zh_pattern.search(query) entities = [] # match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split @@ -210,7 +221,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名 + if ( + pos == "nr" + or pos == "Ng" + or pos == "eng" + or pos == "nz" + or pos == "n" + or pos == "ORG" + or pos == "v" + ): # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -220,22 +239,22 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: entities.append(current_entity) else: try: - nltk.data.find('tokenizers/punkt') - nltk.data.find('corpora/stopwords') + nltk.data.find("tokenizers/punkt") + nltk.data.find("corpora/stopwords") except LookupError: - nltk.download('punkt') - nltk.download('stopwords') + nltk.download("punkt") + nltk.download("stopwords") print("run download") - e_str = re.sub(r'[^\w ]', '', query) + e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) - stop_words = stopwords.words('english') + stop_words = stopwords.words("english") for token in all_tokens: if token not in stop_words: entities.append(token) with self._get_cursor() as cur: cur.execute( f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)] + [" ACCUM ".join(entities)], ) docs = [] for record in cur: @@ -273,8 +292,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) return OracleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a48224070fdc21..b778582e8a7f9f 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PGVECTO_RS_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config PGVECTO_RS_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config PGVECTO_RS_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config PGVECTO_RS_DATABASE is required") return values class PGVectoRS(BaseVector): - def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): super().__init__(collection_name) self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self._client = create_engine(self._url) with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) @@ -80,9 +81,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -133,9 +134,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " - ) + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") result = session.execute(select_statement).fetchall() if result: return [item[0] for item in result] @@ -143,12 +142,11 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete_by_ids(self, ids: list[str]) -> None: @@ -156,13 +154,13 @@ def delete_by_ids(self, ids: list[str]) -> None: select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " ) - result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete(self) -> None: @@ -187,7 +185,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc query_vector, ).label("distance"), ) - .limit(kwargs.get('top_k', 2)) + .limit(kwargs.get("top_k", 2)) .order_by("distance") ) res = session.execute(stmt) @@ -198,11 +196,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for record, dis in results: metadata = record.meta score = 1 - dis - metadata['score'] = score - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + metadata["score"] = score + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if score > score_threshold: - doc = Document(page_content=record.text, - metadata=metadata) + doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs @@ -225,13 +222,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( @@ -243,5 +239,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings password=dify_config.PGVECTO_RS_PASSWORD, database=dify_config.PGVECTO_RS_DATABASE, ), - dim=dim + dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c9f2f35af03d75..b01cd91e07c309 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") @@ -201,8 +201,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) return PGVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 297bff928e8ae8..83d561819ca7b8 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -48,28 +48,25 @@ class QdrantConfig(BaseModel): prefer_grpc: bool = False def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") if not os.path.isabs(path): path = os.path.join(self.root_path, path) - return { - 'path': path - } + return {"path": path} else: return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout, - 'verify': self.endpoint.startswith('https'), - 'grpc_port': self.grpc_port, - 'prefer_grpc': self.prefer_grpc + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": self.endpoint.startswith("https"), + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, } class QdrantVector(BaseVector): - - def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) @@ -80,10 +77,7 @@ def get_type(self) -> str: return VectorType.QDRANT def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -97,9 +91,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -110,12 +104,19 @@ def create_collection(self, collection_name: str, vector_size: int): all_collection_name.append(collection.name) if collection_name not in all_collection_name: from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[self._distance_func], ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) self._client.recreate_collection( collection_name=collection_name, vectors_config=vectors_config, @@ -124,21 +125,24 @@ def create_collection(self, collection_name: str, vector_size: int): ) # create group_id payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) # create doc_id payload index - self._client.create_payload_index(collection_name, Field.DOC_ID.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) # create full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, tokenizer=TokenizerType.MULTILINGUAL, min_token_len=2, max_token_len=20, - lowercase=True + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params ) - self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -147,26 +151,23 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, embeddings, metadatas, uuids, 64, self._group_id - ): - self._client.upsert( - collection_name=self._collection_name, points=points - ) + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) return added_ids def _generate_rest_batches( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest + texts_iterator = iter(texts) embeddings_iterator = iter(embeddings) metadatas_iterator = iter(metadatas or []) @@ -203,13 +204,13 @@ def _generate_rest_batches( @classmethod def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, ) -> list[dict]: payloads = [] for i, text in enumerate(texts): @@ -219,18 +220,11 @@ def _build_payloads( "calling .from_texts or .add_texts on Qdrant instance." ) metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) return payloads def delete_by_metadata_field(self, key: str, value: str): - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -248,9 +242,7 @@ def delete_by_metadata_field(self, key: str, value: str): self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -275,9 +267,7 @@ def delete(self): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -288,7 +278,6 @@ def delete(self): raise e def delete_by_ids(self, ids: list[str]) -> None: - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,9 +293,7 @@ def delete_by_ids(self, ids: list[str]) -> None: ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -324,15 +311,13 @@ def text_exists(self, id: str) -> bool: all_collection_name.append(collection.name) if self._collection_name not in all_collection_name: return False - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[id] - ) + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) return len(response) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + filter = models.Filter( must=[ models.FieldCondition( @@ -348,22 +333,22 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", .0) + score_threshold=kwargs.get("score_threshold", 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if result.score > score_threshold: - metadata['score'] = result.score + metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -372,6 +357,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: List of documents most similar to the query text and distance for each. """ from qdrant_client.http import models + scroll_filter = models.Filter( must=[ models.FieldCondition( @@ -381,24 +367,21 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: models.FieldCondition( key="page_content", match=models.MatchText(text=query), - ) + ), ] ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, - limit=kwargs.get('top_k', 2), + limit=kwargs.get("top_k", 2), with_payload=True, - with_vectors=True - + with_vectors=True, ) results = response[0] documents = [] for result in results: if result: - document = self._document_from_scored_point( - result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - ) + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) documents.append(document) return documents @@ -410,10 +393,10 @@ def _reload_if_needed(self): @classmethod def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), @@ -425,24 +408,25 @@ def _document_from_scored_point( class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) if not dataset.index_struct_dict: - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) config = current_app.config return QdrantVector( @@ -454,6 +438,6 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings root_path=config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED - ) + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ), ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 63ad0682d729fe..d8e4ff628c1f65 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -33,28 +33,29 @@ class RelytConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config RELYT_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config RELYT_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config RELYT_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config RELYT_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config RELYT_DATABASE is required") return values class RelytVector(BaseVector): - def __init__(self, collection_name: str, config: RelytConfig, group_id: str): super().__init__(collection_name) self.embedding_dimension = 1536 self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self.client = create_engine(self._url) self._fields = [] self._group_id = group_id @@ -70,9 +71,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -110,7 +111,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** ids = [str(uuid.uuid1()) for _ in documents] metadatas = [d.metadata for d in documents] for metadata in metadatas: - metadata['group_id'] = self._group_id + metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] # Define the table schema @@ -127,9 +128,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** chunks_table_data = [] with self.client.connect() as conn: with conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): chunks_table_data.append( { "id": chunk_id, @@ -196,15 +195,13 @@ def delete_by_uuids(self, ids: list[str] = None): return False def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) def delete_by_ids(self, ids: list[str]) -> None: - with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) @@ -228,38 +225,34 @@ def text_exists(self, id: str) -> bool: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get('top_k')), - embedding=query_vector, - filter=kwargs.get('filter') + k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if 1 - score > score_threshold: docs.append(document) return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: from sqlalchemy.engine import Row except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) + raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: conditions = [ - f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1 + f"metadata->>{key!r} in ({', '.join(map(repr, value))})" + if len(value) > 1 else f"metadata->>{key!r} = {value[0]!r}" for key, value in filter.items() ] @@ -305,13 +298,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: class RelytVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.RELYT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name)) return RelytVector( collection_name=collection_name, @@ -322,5 +314,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings password=dify_config.RELYT_PASSWORD, database=dify_config.RELYT_DATABASE, ), - group_id=dataset.id + group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3325a1028ece52..ada0c5cf460f43 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -25,16 +25,11 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = 1, - replicas: int = 2, + shard: int = (1,) + replicas: int = (2,) def to_tencent_params(self): - return { - 'url': self.url, - 'username': self.username, - 'key': self.api_key, - 'timeout': self.timeout - } + return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} class TencentVector(BaseVector): @@ -61,13 +56,10 @@ def _init_database(self): return self._client.create_database(database_name=self._client_config.database) def get_type(self) -> str: - return 'tencent' + return "tencent" def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def _has_collection(self) -> bool: collections = self._db.list_collections() @@ -77,9 +69,9 @@ def _has_collection(self) -> bool: return False def _create_collection(self, dimension: int) -> None: - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return @@ -101,9 +93,7 @@ def _create_collection(self, dimension: int) -> None: raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) index = vdb_index.Index( - vdb_index.FilterIndex( - self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY - ), + vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), vdb_index.VectorIndex( self.field_vector, dimension, @@ -111,12 +101,8 @@ def _create_collection(self, dimension: int) -> None: metric_type, params, ), - vdb_index.FilterIndex( - self.field_text, enum.FieldType.String, enum.IndexType.FILTER - ), - vdb_index.FilterIndex( - self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER - ), + vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), + vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), ) self._db.create_collection( @@ -163,15 +149,14 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - - res = self._db.collection(self._collection_name).search(vectors=[query_vector], - params=document.HNSWSearchParams( - ef=kwargs.get("ef", 10)), - retrieve_vector=False, - limit=kwargs.get('top_k', 4), - timeout=self._client_config.timeout, - ) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + res = self._db.collection(self._collection_name).search( + vectors=[query_vector], + params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + timeout=self._client_config.timeout, + ) + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -200,15 +185,13 @@ def delete(self) -> None: class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) return TencentVector( collection_name=collection_name, @@ -220,5 +203,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index d3685c099162fd..0e4b3f67a12e45 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel): database: str program_name: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config TIDB_VECTOR_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config TIDB_VECTOR_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config TIDB_VECTOR_DATABASE is required") - if not values['program_name']: + if not values["program_name"]: raise ValueError("config APPLICATION_NAME is required") return values class TiDBVector(BaseVector): - def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType + return Table( self._collection_name, self._orm_base.metadata, - Column('id', String(36), primary_key=True, nullable=False), - Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("id", String(36), primary_key=True, nullable=False), + Column( + "vector", + VectorType(dim), + nullable=False, + comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})", + ), Column("text", TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), - Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), - extend_existing=True + Column( + "update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), + extend_existing=True, ) - def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"): super().__init__(collection_name) self._client_config = config - self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") + self._url = ( + f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" + f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}" + ) self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -83,9 +92,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return with Session(self._engine) as session: @@ -116,9 +125,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** chunks_table_data = [] with self._engine.connect() as conn: with conn.begin(): - for id, text, meta, embedding in zip( - ids, texts, metas, embeddings - ): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) # Execute the batch insert when the batch size is reached @@ -133,12 +140,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** return ids def text_exists(self, id: str) -> bool: - result = self.get_ids_by_metadata_field('doc_id', id) + result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ ) @@ -180,20 +187,22 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 - filter = kwargs.get('filter') + filter = kwargs.get("filter") distance = 1 - score_threshold query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" - logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + logger.debug( + f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + ) docs = [] - if self._distance_func == 'l2': - tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'cosine': - tidb_func = 'Vec_Cosine_distance' + if self._distance_func == "l2": + tidb_func = "Vec_l2_distance" + elif self._distance_func == "cosine": + tidb_func = "Vec_Cosine_distance" else: - tidb_func = 'Vec_Cosine_distance' + tidb_func = "Vec_Cosine_distance" with Session(self._engine) as session: select_statement = sql_text( @@ -208,7 +217,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: metadata = json.loads(meta) - metadata['score'] = 1 - distance + metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -224,15 +233,13 @@ def delete(self) -> None: class TiDBVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) return TiDBVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 3f70e8b60814c1..fb80cdec87b53e 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -7,7 +7,6 @@ class BaseVector(ABC): - def __init__(self, collection_name: str): self._collection_name = collection_name @@ -39,18 +38,11 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: raise NotImplementedError @abstractmethod - def search_by_vector( - self, - query_vector: list[float], - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: raise NotImplementedError @abstractmethod - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def delete(self) -> None: @@ -58,7 +50,7 @@ def delete(self) -> None: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -66,7 +58,7 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 627d7c3aeb5a18..7d2db140df2ad1 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -20,17 +20,14 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings @staticmethod def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: - index_struct_dict = { - "type": vector_type, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict class Vector: def __init__(self, dataset: Dataset, attributes: list = None): if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page'] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -39,7 +36,7 @@ def __init__(self, dataset: Dataset, attributes: list = None): def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] + vector_type = self._dataset.index_struct_dict["type"] if not vector_type: raise ValueError("Vector store must be specified.") @@ -52,45 +49,59 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: match vector_type: case VectorType.CHROMA: from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return ChromaVectorFactory case VectorType.MILVUS: from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory + return MyScaleVectorFactory case VectorType.PGVECTOR: from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory case VectorType.PGVECTO_RS: from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory case VectorType.QDRANT: from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory case VectorType.ELASTICSEARCH: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory case VectorType.WEAVIATE: from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory + return TencentVectorFactory case VectorType.ORACLE: from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory + return OracleVectorFactory case VectorType.OPENSEARCH: from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case VectorType.ANALYTICDB: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory + return AnalyticdbVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -98,22 +109,14 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: def create(self, texts: list = None, **kwargs): if texts: embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create( - texts=texts, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) def add_texts(self, documents: list[Document], **kwargs): - if kwargs.get('duplicate_check', False): + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) - self._vector_processor.create( - texts=documents, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) @@ -124,24 +127,18 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None: self._vector_processor.delete_by_metadata_field(key, value) - def search_by_vector( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) def delete(self) -> None: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: @@ -151,14 +148,13 @@ def _get_embeddings(self) -> Embeddings: tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model - + model=self._dataset.embedding_model, ) return CacheEmbedding(embedding_model) def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 317ca6abc8c89d..ba04ea879d9b43 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,17 +2,17 @@ class VectorType(str, Enum): - ANALYTICDB = 'analyticdb' - CHROMA = 'chroma' - MILVUS = 'milvus' - MYSCALE = 'myscale' - PGVECTOR = 'pgvector' - PGVECTO_RS = 'pgvecto-rs' - QDRANT = 'qdrant' - RELYT = 'relyt' - TIDB_VECTOR = 'tidb_vector' - WEAVIATE = 'weaviate' - OPENSEARCH = 'opensearch' - TENCENT = 'tencent' - ORACLE = 'oracle' - ELASTICSEARCH = 'elasticsearch' + ANALYTICDB = "analyticdb" + CHROMA = "chroma" + MILVUS = "milvus" + MYSCALE = "myscale" + PGVECTOR = "pgvector" + PGVECTO_RS = "pgvecto-rs" + QDRANT = "qdrant" + RELYT = "relyt" + TIDB_VECTOR = "tidb_vector" + WEAVIATE = "weaviate" + OPENSEARCH = "opensearch" + TENCENT = "tencent" + ORACLE = "oracle" + ELASTICSEARCH = "elasticsearch" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 205fe850c35838..750172b015edbf 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel): api_key: Optional[str] = None batch_size: int = 100 - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: + if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values class WeaviateVector(BaseVector): - def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): super().__init__(collection_name) self._client = self._init_client(config) @@ -43,10 +42,7 @@ def _init_client(self, config: WeaviateConfig) -> weaviate.Client: try: client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None + url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) except requests.exceptions.ConnectionError: raise ConnectionError("Vector database connection error") @@ -68,10 +64,10 @@ def get_type(self) -> str: def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + if not class_prefix.endswith("_Node"): # original class_prefix - class_prefix += '_Node' + class_prefix += "_Node" return class_prefix @@ -79,10 +75,7 @@ def get_collection_name(self, dataset: Dataset) -> str: return Dataset.gen_collection_name_by_id(dataset_id) def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): # create collection @@ -91,9 +84,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) @@ -129,17 +122,9 @@ def delete_by_metadata_field(self, key: str, value: str): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - where_filter = { - "operator": "Equal", - "path": [key], - "valueText": value - } - - self._client.batch.delete_objects( - class_name=self._collection_name, - where=where_filter, - output='minimal' - ) + where_filter = {"operator": "Equal", "path": [key], "valueText": value} + + self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") def delete(self): # check whether the index already exists @@ -154,11 +139,19 @@ def text_exists(self, id: str) -> bool: # check whether the index already exists if not self._client.schema.contains(schema): return False - result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": id, - }).with_limit(1).do() + result = ( + self._client.query.get(collection_name) + .with_additional(["id"]) + .with_where( + { + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + } + ) + .with_limit(1) + .do() + ) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -211,13 +204,13 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 # check score threshold if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -240,15 +233,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + properties = ["text"] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - additional = res.pop('_additional') - docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) + additional = res.pop("_additional") + docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: @@ -271,20 +264,19 @@ def _json_serializable(self, value: Any) -> Any: class WeaviateVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT, api_key=dify_config.WEAVIATE_API_KEY, - batch_size=dify_config.WEAVIATE_BATCH_SIZE + batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), - attributes=attributes + attributes=attributes, ) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 96a15be7426c67..0d4dff5b89027c 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -12,10 +12,10 @@ class DatasetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -41,9 +41,9 @@ def user_id(self) -> Any: @property def docs(self) -> dict[str, Document]: - document_segments = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id - ).all() + document_segments = ( + db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + ) output = {} for document_segment in document_segments: @@ -55,48 +55,45 @@ def docs(self) -> dict[str, Document]: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) return output - def add_documents( - self, docs: Sequence[Document], allow_update: bool = True - ) -> None: - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == self._document_id - ).scalar() + def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == self._document_id) + .scalar() + ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == 'high_quality': + if self._dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model + model=self._dataset.embedding_model, ) for doc in docs: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: raise ValueError( - f"doc_id {doc.metadata['doc_id']} already exists. " - "Set allow_update to True to overwrite." + f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite." ) # calc embedding use tokens if embedding_model: - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[doc.page_content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content]) else: tokens = 0 @@ -107,8 +104,8 @@ def add_documents( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.metadata['doc_id'], - index_node_hash=doc.metadata['doc_hash'], + index_node_id=doc.metadata["doc_id"], + index_node_hash=doc.metadata["doc_hash"], position=max_position, content=doc.page_content, word_count=len(doc.page_content), @@ -116,15 +113,15 @@ def add_documents( enabled=False, created_by=self._user_id, ) - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) else: segment_document.content = doc.page_content - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') - segment_document.index_node_hash = doc.metadata['doc_hash'] + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") + segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens @@ -135,9 +132,7 @@ def document_exists(self, doc_id: str) -> bool: result = self.get_document_segment(doc_id) return result is not None - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -153,7 +148,7 @@ def get_document( "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: @@ -188,9 +183,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: return document_segment.index_node_hash def get_document_segment(self, doc_id: str) -> DocumentSegment: - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id, - DocumentSegment.index_node_id == doc_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .first() + ) return document_segment diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index abfdafcfa251a4..f4c7b4b5f78b7a 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -4,6 +4,7 @@ In addition, content loading code should provide a lazy loading interface by default. """ + from __future__ import annotations import contextlib diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 0470569f393020..5b674039024189 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import csv from typing import Optional @@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, ): """Initialize with file path.""" self._file_path = file_path @@ -57,7 +58,7 @@ def _read_from_file(self, csvfile) -> list[Document]: docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: @@ -67,7 +68,7 @@ def _read_from_file(self, csvfile) -> list[Document]: for i, row in df.iterrows(): content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns) - source = row[self.source_column] if self.source_column else '' + source = row[self.source_column] if self.source_column else "" metadata = {"source": source, "row": i} doc = Document(page_content=content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7479b1d97b8fdc..3692b5d19dfb65 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,6 +10,7 @@ class NotionInfo(BaseModel): """ Notion import info. """ + notion_workspace_id: str notion_obj_id: str notion_page_type: str @@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel): """ website import info. """ + provider: str job_id: str url: str @@ -43,6 +45,7 @@ class ExtractSetting(BaseModel): """ Model class for provider response. """ + datasource_type: str upload_file: Optional[UploadFile] = None notion_info: Optional[NotionInfo] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 526c66042ca244..fc331657195454 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import os from typing import Optional @@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" + """Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] file_extension = os.path.splitext(self._file_path)[-1].lower() - if file_extension == '.xlsx': + if file_extension == ".xlsx": wb = load_workbook(self._file_path, data_only=True) for sheet_name in wb.sheetnames: sheet = wb[sheet_name] @@ -44,35 +40,38 @@ def extract(self) -> list[Document]: continue df = pd.DataFrame(data, columns=cols) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for index, row in df.iterrows(): page_content = [] for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): - cell = sheet.cell(row=index + 2, - column=col_index + 1) # +2 to account for header and 1-based index + cell = sheet.cell( + row=index + 2, column=col_index + 1 + ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" page_content.append(f'"{k}":"{value}"') else: page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) - elif file_extension == '.xls': - excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + elif file_extension == ".xls": + excel_file = pd.ExcelFile(self._file_path, engine="xlrd") for sheet_name in excel_file.sheet_names: df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for _, row in df.iterrows(): page_content = [] for k, v in row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) else: raise ValueError(f"Unsupported file extension: {file_extension}") diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index f7a08135f57518..a00b3cba533f6d 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -29,61 +29,60 @@ from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json'] +SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"] USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" class ExtractProcessor: @classmethod - def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ - -> Union[list[Document], str]: + def load_from_upload_file( + cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False + ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=upload_file, - document_model='text_model' + datasource_type="upload_file", upload_file=upload_file, document_model="text_model" ) if return_text: - delimiter = '\n' + delimiter = "\n" return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) else: return cls.extract(extract_setting, is_automatic) @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={ - "User-Agent": USER_AGENT - }) + response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix - if not suffix and suffix != '.': + if not suffix and suffix != ".": # get content-type - if response.headers.get('Content-Type'): - suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + if response.headers.get("Content-Type"): + suffix = "." + response.headers.get("Content-Type").split("/")[-1] else: - content_disposition = response.headers.get('Content-Disposition') + content_disposition = response.headers.get("Content-Disposition") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = '.' + re.search(r'\.(\w+)$', filename).group(1) + suffix = "." + re.search(r"\.(\w+)$", filename).group(1) file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: file.write(response.content) - extract_setting = ExtractSetting( - datasource_type="upload_file", - document_model='text_model' - ) + extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: - delimiter = '\n' - return delimiter.join([document.page_content for document in cls.extract( - extract_setting=extract_setting, file_path=file_path)]) + delimiter = "\n" + return delimiter.join( + [ + document.page_content + for document in cls.extract(extract_setting=extract_setting, file_path=file_path) + ] + ) else: return cls.extract(extract_setting=extract_setting, file_path=file_path) @classmethod - def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, - file_path: str = None) -> list[Document]: + def extract( + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None + ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: @@ -96,50 +95,56 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - if etl_type == 'Unstructured': - if file_extension == '.xlsx' or file_extension == '.xls': + if etl_type == "Unstructured": + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: - extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + elif file_extension in [".md", ".markdown"]: + extractor = ( + UnstructuredMarkdownExtractor(file_path, unstructured_api_url) + if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + ) + elif file_extension in [".htm", ".html"]: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == '.msg': + elif file_extension == ".msg": extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) - elif file_extension == '.eml': + elif file_extension == ".eml": extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) - elif file_extension == '.ppt': + elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) - elif file_extension == '.pptx': + elif file_extension == ".pptx": extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) - elif file_extension == '.xml': + elif file_extension == ".xml": extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) else: # txt - extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + extractor = ( + UnstructuredTextExtractor(file_path, unstructured_api_url) + if is_automatic else TextExtractor(file_path, autodetect_encoding=True) + ) else: - if file_extension == '.xlsx' or file_extension == '.xls': + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: + elif file_extension in [".md", ".markdown"]: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + elif file_extension in [".htm", ".html"]: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path) else: # txt @@ -155,13 +160,13 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: - if extract_setting.website_info.provider == 'firecrawl': + if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, job_id=extract_setting.website_info.job_id, tenant_id=extract_setting.website_info.tenant_id, mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content + only_main_content=extract_setting.website_info.only_main_content, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332d237..582eca94df71e1 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -1,12 +1,11 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod class BaseExtractor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 2b85ad9739881e..054ce5f4b2e6c6 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -9,108 +9,98 @@ class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' - if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': - raise ValueError('No API key provided') + self.base_url = base_url or "https://api.firecrawl.dev" + if self.api_key is None and self.base_url == "https://api.firecrawl.dev": + raise ValueError("No API key provided") def scrape_url(self, url, params=None) -> dict: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - json_data = {'url': url} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + json_data = {"url": url} if params: json_data.update(params) - response = requests.post( - f'{self.base_url}/v0/scrape', - headers=headers, - json=json_data - ) + response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: response = response.json() - if response['success'] == True: - data = response['data'] + if response["success"] == True: + data = response["data"] return { - 'title': data.get('metadata').get('title'), - 'description': data.get('metadata').get('description'), - 'source_url': data.get('metadata').get('sourceURL'), - 'markdown': data.get('markdown') + "title": data.get("metadata").get("title"), + "description": data.get("metadata").get("description"), + "source_url": data.get("metadata").get("sourceURL"), + "markdown": data.get("markdown"), } else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') elif response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: headers = self._prepare_headers() - json_data = {'url': url} + json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: - job_id = response.json().get('jobId') + job_id = response.json().get("jobId") return job_id else: - self._handle_error(response, 'start crawl job') + self._handle_error(response, "start crawl job") def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() - response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() - if crawl_status_response.get('status') == 'completed': - total = crawl_status_response.get('total', 0) + if crawl_status_response.get("status") == "completed": + total = crawl_status_response.get("total", 0) if total == 0: - raise Exception('Failed to check crawl status. Error: No page found') - data = crawl_status_response.get('data', []) + raise Exception("Failed to check crawl status. Error: No page found") + data = crawl_status_response.get("data", []) url_data_list = [] for item in data: - if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - 'title': item.get('metadata').get('title'), - 'description': item.get('metadata').get('description'), - 'source_url': item.get('metadata').get('sourceURL'), - 'markdown': item.get('markdown') + "title": item.get("metadata").get("title"), + "description": item.get("metadata").get("description"), + "source_url": item.get("metadata").get("sourceURL"), + "markdown": item.get("markdown"), } url_data_list.append(url_data) if url_data_list: - file_key = 'website_files/' + job_id + '.txt' + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + storage.save(file_key, json.dumps(url_data_list).encode("utf-8")) return { - 'status': 'completed', - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': url_data_list + "status": "completed", + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": url_data_list, } else: return { - 'status': crawl_status_response.get('status'), - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': [] + "status": crawl_status_response.get("status"), + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": [], } else: - self._handle_error(response, 'check crawl status') + self._handle_error(response, "check crawl status") def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.post(url, headers=headers, json=data) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response @@ -119,13 +109,11 @@ def _get_request(self, url, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.get(url, headers=headers) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response def _handle_error(self, response, action): - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') - - + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py index 8e2f107e5eb795..b33ce167c21c82 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -5,7 +5,7 @@ class FirecrawlWebExtractor(BaseExtractor): """ - Crawl and scrape websites and return content in clean llm-ready markdown. + Crawl and scrape websites and return content in clean llm-ready markdown. Args: @@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor): mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. """ - def __init__( - self, - url: str, - job_id: str, - tenant_id: str, - mode: str = 'crawl', - only_main_content: bool = False - ): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id @@ -33,28 +26,31 @@ def __init__( def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] - if self.mode == 'crawl': - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) if crawl_data is None: return [] - document = Document(page_content=crawl_data.get('markdown', ''), - metadata={ - 'source_url': crawl_data.get('source_url'), - 'description': crawl_data.get('description'), - 'title': crawl_data.get('title') - } - ) + document = Document( + page_content=crawl_data.get("markdown", ""), + metadata={ + "source_url": crawl_data.get("source_url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) documents.append(document) - elif self.mode == 'scrape': - scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, - self.only_main_content) + elif self.mode == "scrape": + scrape_data = WebsiteService.get_scrape_url_data( + "firecrawl", self._url, self.tenant_id, self.only_main_content + ) - document = Document(page_content=scrape_data.get('markdown', ''), - metadata={ - 'source_url': scrape_data.get('source_url'), - 'description': scrape_data.get('description'), - 'title': scrape_data.get('title') - } - ) + document = Document( + page_content=scrape_data.get("markdown", ""), + metadata={ + "source_url": scrape_data.get("source_url"), + "description": scrape_data.get("description"), + "title": scrape_data.get("title"), + }, + ) documents.append(document) return documents diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 0c17a47b329cab..9a21d4272a68e3 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -37,9 +37,7 @@ def read_and_detect(file_path: str) -> list[dict]: try: encodings = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - raise TimeoutError( - f"Timeout reached while detecting encoding for {file_path}" - ) + raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") if all(encoding["encoding"] is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb53062559a4d..560c2d1d84b04e 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor @@ -6,7 +7,6 @@ class HtmlExtractor(BaseExtractor): - """ Load html files. @@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str - ): + def __init__(self, file_path: str): """Initialize with file path.""" self._file_path = file_path @@ -27,8 +24,8 @@ def extract(self) -> list[Document]: def _load_as_text(self) -> str: with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') + soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() - text = text.strip() if text else '' + text = text.strip() if text else "" - return text \ No newline at end of file + return text diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index b24cf2e1707376..ca125ecf554d7e 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import re from typing import Optional, cast @@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = False, - remove_images: bool = False, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = False, + remove_images: bool = False, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -78,13 +79,10 @@ def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str] if current_header is not None: # pass linting, assert keys are defined markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] + markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] return markdown_tups diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7e839804c8c176..b02e30de623b77 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -21,22 +21,21 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" # if user want split by headings, use the corresponding splitter HEADING_SPLITTER = { - 'heading_1': '# ', - 'heading_2': '## ', - 'heading_3': '### ', + "heading_1": "# ", + "heading_2": "## ", + "heading_3": "### ", } -class NotionExtractor(BaseExtractor): +class NotionExtractor(BaseExtractor): def __init__( - self, - notion_workspace_id: str, - notion_obj_id: str, - notion_page_type: str, - tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - + self, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + tenant_id: str, + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model @@ -46,46 +45,38 @@ def __init__( if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, - self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." + "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`." ) self._notion_access_token = integration_token def extract(self) -> list[Document]: - self.update_last_edited_time( - self._document_model - ) + self.update_last_edited_time(self._document_model) text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) return text_docs - def _load_data_as_documents( - self, notion_obj_id: str, notion_page_type: str - ) -> list[Document]: + def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]: docs = [] - if notion_page_type == 'database': + if notion_page_type == "database": # get all the pages in the database page_text_documents = self._get_notion_database_data(notion_obj_id) docs.extend(page_text_documents) - elif notion_page_type == 'page': + elif notion_page_type == "page": page_text_list = self._get_notion_block_data(notion_obj_id) - docs.append(Document(page_content='\n'.join(page_text_list))) + docs.append(Document(page_content="\n".join(page_text_list))) else: raise ValueError("notion page type not supported") return docs - def _get_notion_database_data( - self, database_id: str, query_dict: dict[str, Any] = {} - ) -> list[Document]: + def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -100,50 +91,50 @@ def _get_notion_database_data( data = res.json() database_content = [] - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: return [] for result in data["results"]: - properties = result['properties'] + properties = result["properties"] data = {} for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': + type = property_value["type"] + if type == "multi_select": value = [] multi_select_list = property_value[type] for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': + value.append(multi_select["name"]) + elif type == "rich_text" or type == "title": if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] + value = property_value[type][0]["plain_text"] else: - value = '' - elif type == 'select' or type == 'status': + value = "" + elif type == "select" or type == "status": if property_value[type]: - value = property_value[type]['name'] + value = property_value[type]["name"] else: - value = '' + value = "" else: value = property_value[type] data[property_name] = value row_dict = {k: v for k, v in data.items() if v} - row_content = '' + row_content = "" for key, value in row_dict.items(): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} - value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items()) - row_content = row_content + f'{key}:{value_content}\n' + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - row_content = row_content + f'{key}:{value}\n' + row_content = row_content + f"{key}:{value}\n" database_content.append(row_content) - return [Document(page_content='\n'.join(database_content))] + return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", block_url, @@ -152,14 +143,14 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) text += "\n\n" @@ -175,17 +166,15 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -199,7 +188,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -209,16 +198,16 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: break for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) result_lines_arr.append(text) @@ -233,17 +222,15 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: - result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -260,7 +247,7 @@ def _read_table_rows(self, block_id: str) -> str: start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while not done: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -270,28 +257,28 @@ def _read_table_rows(self, block_id: str) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() # get table headers text table_header_cell_texts = [] - table_header_cells = data["results"][0]['table_row']['cells'] + table_header_cells = data["results"][0]["table_row"]["cells"] for table_header_cell in table_header_cells: if table_header_cell: for table_header_cell_text in table_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) else: - table_header_cell_texts.append('') + table_header_cell_texts.append("") # Initialize Markdown table with headers markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" - markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n" # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - table_column_cells = data["results"][i + 1]['table_row']['cells'] + table_column_cells = data["results"][i + 1]["table_row"]["cells"] for j in range(len(table_column_cells)): if table_column_cells[j]: for table_column_cell_text in table_column_cells[j]: @@ -315,10 +302,8 @@ def update_last_edited_time(self, document_model: DocumentModel): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - DocumentModel.data_source_info: json.dumps(data_source_info) - } + data_source_info["last_edited_time"] = last_edited_time + update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.commit() @@ -326,7 +311,7 @@ def update_last_edited_time(self, document_model: DocumentModel): def get_notion_last_edited_time(self) -> str: obj_id = self._notion_obj_id page_type = self._notion_page_type - if page_type == 'database': + if page_type == "database": retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) @@ -341,7 +326,7 @@ def get_notion_last_edited_time(self) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - json=query_dict + json=query_dict, ) data = res.json() @@ -352,14 +337,16 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) ).first() if not data_source_binding: - raise Exception(f'No notion data source binding found for tenant {tenant_id} ' - f'and notion workspace {notion_workspace_id}') + raise Exception( + f"No notion data source binding found for tenant {tenant_id} " + f"and notion workspace {notion_workspace_id}" + ) return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 0864fec6c8bcd1..57cb9610ba267e 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from collections.abc import Iterator from typing import Optional @@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - file_cache_key: Optional[str] = None - ): + def __init__(self, file_path: str, file_cache_key: Optional[str] = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = '' + plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode('utf-8') + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -43,12 +40,12 @@ def extract(self) -> list[Document]: # save plaintext file for caching if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) + storage.save(plaintext_file_key, text.encode("utf-8")) return documents def load( - self, + self, ) -> Iterator[Document]: """Lazy load given path as pages.""" blob = Blob.from_path(self._file_path) diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index ac5d0920cffb48..ed0ae41f51e390 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor @@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 0323b14a4a34fd..a525c9e9e3c443 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -8,13 +8,12 @@ class UnstructuredWordExtractor(BaseExtractor): - """Loader that uses unstructured to load word documents. - """ + """Loader that uses unstructured to load word documents.""" def __init__( - self, - file_path: str, - api_url: str, + self, + file_path: str, + api_url: str, ): """Initialize with file path.""" self._file_path = file_path @@ -24,9 +23,7 @@ def extract(self) -> list[Document]: from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.file_utils.filetype import FileType, detect_filetype - unstructured_version = tuple( - int(x) for x in __unstructured_version__.split(".") - ) + unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: import magic # noqa: F401 @@ -53,6 +50,7 @@ def extract(self) -> list[Document]: elements = partition_docx(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2e704f187d05d6..34c6811b671460 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -26,6 +26,7 @@ def __init__( def extract(self) -> list[Document]: from unstructured.partition.email import partition_email + elements = partition_email(filename=self._file_path) # noinspection PyBroadException @@ -34,15 +35,16 @@ def extract(self) -> list[Document]: element_text = element.text.strip() padding_needed = 4 - len(element_text) % 4 - element_text += '=' * padding_needed + element_text += "=" * padding_needed element_decode = base64.b64decode(element_text) - soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser') + soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() except Exception: pass from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 44cf958ea2b636..fa50fa76b24980 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -28,6 +28,7 @@ def extract(self) -> list[Document]: elements = partition_epub(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 144b4e0c1d7a91..fc3ff1069337fa 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -38,6 +38,7 @@ def extract(self) -> list[Document]: elements = partition_md(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index ad09b79eb00a07..8091e83e851252 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ def extract(self) -> list[Document]: elements = partition_msg(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index d354b593ed7af6..b69394b3b1db68 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str, - api_key: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 6fcbb5feb991d0..6ed4a0dfb3881e 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py index f4a4adbc1600fd..22dfdd20752cbf 100644 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ def extract(self) -> list[Document]: elements = partition_text(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 6aef8e0f7e2718..3bffc01fbf3c5b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ def extract(self) -> list[Document]: elements = partition_xml(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 15822867bbe5e4..2db00d161b366e 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import datetime import logging import mimetypes @@ -21,6 +22,7 @@ logger = logging.getLogger(__name__) + class WordExtractor(BaseExtractor): """Load docx files. @@ -43,9 +45,7 @@ def __init__(self, file_path: str, tenant_id: str, user_id: str): r = requests.get(self.file_path) if r.status_code != 200: - raise ValueError( - f"Check the url of your file; returned status code {r.status_code}" - ) + raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path self.temp_file = tempfile.NamedTemporaryFile() @@ -60,11 +60,13 @@ def __del__(self) -> None: def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, 'storage') - return [Document( - page_content=content, - metadata={"source": self.file_path}, - )] + content = self.parse_docx(self.file_path, "storage") + return [ + Document( + page_content=content, + metadata={"source": self.file_path}, + ) + ] @staticmethod def _is_valid_url(url: str) -> bool: @@ -84,18 +86,18 @@ def _extract_images_from_docx(self, doc, image_folder): url = rel.reltype response = requests.get(url, stream=True) if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers['Content-Type']) + image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) else: continue else: - image_ext = rel.target_ref.split('.')[-1] + image_ext = rel.target_ref.split(".")[-1] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) @@ -112,12 +114,14 @@ def _extract_images_from_docx(self, doc, image_folder): created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = ( + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + ) return image_map @@ -167,8 +171,8 @@ def _parse_cell(self, cell, image_map): def _parse_cell_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue @@ -184,16 +188,16 @@ def _parse_cell_paragraph(self, paragraph, image_map): def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): - embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if embed_id: rel_target = run.part.rels[embed_id].target_ref if rel_target in image_map: paragraph_content.append(image_map[rel_target]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ' '.join(paragraph_content) if paragraph_content else '' + return " ".join(paragraph_content) if paragraph_content else "" def parse_docx(self, docx_path, image_folder): doc = DocxDocument(docx_path) @@ -204,60 +208,59 @@ def parse_docx(self, docx_path, image_folder): image_map = self._extract_images_from_docx(doc, image_folder) hyperlinks_url = None - url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") for para in doc.paragraphs: for run in para.runs: if run.text and hyperlinks_url: - result = f' [{run.text}]({hyperlinks_url}) ' + result = f" [{run.text}]({hyperlinks_url}) " run.text = result hyperlinks_url = None - if 'HYPERLINK' in run.element.xml: + if "HYPERLINK" in run.element.xml: try: xml = ET.XML(run.element.xml) x_child = [c for c in xml.iter() if c is not None] for x in x_child: if x_child is None: continue - if x.tag.endswith('instrText'): + if x.tag.endswith("instrText"): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: logger.error(e) - - - def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): + if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( - './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" + ) for drawing in drawing_elements: blip_elements = drawing.findall( - './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + ) for blip in blip_elements: embed_id = blip.get( - '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ''.join(paragraph_content) if paragraph_content else '' + return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if hasattr(element, 'tag'): - if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + if hasattr(element, "tag"): + if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) if parsed_paragraph: content.append(parsed_paragraph) - elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) - return '\n'.join(content) - + return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 630387fe3a8e3d..be857bd12215fd 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod from typing import Optional @@ -15,8 +16,7 @@ class BaseIndexProcessor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -34,18 +34,24 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: raise NotImplementedError @abstractmethod - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - if processing_rule['mode'] == "custom": + if processing_rule["mode"] == "custom": # The user-defined segmentation rule - rules = processing_rule['rules'] + rules = processing_rule["rules"] segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: @@ -53,22 +59,22 @@ def _get_splitter(self, processing_rule: dict, separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, + chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index df43a6491074b8..9b855ece2c3512 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -7,8 +7,7 @@ class IndexProcessorFactory: - """IndexProcessorInit. - """ + """IndexProcessorInit.""" def __init__(self, index_type: str): self._index_type = index_type @@ -22,7 +21,6 @@ def init_index_processor(self) -> BaseIndexProcessor: if self._index_type == IndexType.PARAGRAPH_INDEX.value: return ParagraphIndexProcessor() elif self._index_type == IndexType.QA_INDEX.value: - return QAIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index bd7f6093bd1dd0..ed5712220f072e 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import uuid from typing import Optional @@ -15,33 +16,32 @@ class ParagraphIndexProcessor(BaseIndexProcessor): - def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -55,7 +55,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: return all_documents def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: @@ -63,7 +63,7 @@ def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -76,17 +76,29 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: else: keyword.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index a44fd9803682f7..1dbc473281daf6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import logging import re import threading @@ -23,33 +24,33 @@ class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) # Split the text documents into nodes. all_documents = [] all_qa_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -61,14 +62,18 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': kwargs.get('tenant_id'), - 'document_node': doc, - 'all_qa_documents': all_qa_documents, - 'document_language': kwargs.get('doc_language', 'English')}) + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -76,9 +81,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: - # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -86,7 +90,7 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: df = pd.read_csv(file) text_docs = [] for index, row in df.iterrows(): - data = Document(page_content=row[0], metadata={'answer': row[1]}) + data = Document(page_content=row[0], metadata={"answer": row[1]}) text_docs.append(data) if len(text_docs) == 0: raise ValueError("The CSV file is empty.") @@ -96,7 +100,7 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: return text_docs def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -107,17 +111,29 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: else: vector.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict): + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ): # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) @@ -134,12 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a document_qa_list = self._format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -151,10 +167,4 @@ def _format_split_text(self, text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 6f3c1c5d343977..0ff1fdb81cb870 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -55,9 +55,7 @@ async def atransform_documents( """ @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -68,9 +66,7 @@ def transform_documents( """ @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py index afbb9fd89d406d..d4894e3cc6731f 100644 --- a/api/core/rag/rerank/constants/rerank_mode.py +++ b/api/core/rag/rerank/constants/rerank_mode.py @@ -2,7 +2,5 @@ class RerankMode(Enum): - - RERANKING_MODEL = 'reranking_model' - WEIGHTED_SCORE = 'weighted_score' - + RERANKING_MODEL = "reranking_model" + WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index d9067da2880fec..6356ff87ab8ef1 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -8,8 +8,14 @@ class RerankModelRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -23,19 +29,15 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) documents = unique_documents rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) rerank_documents = [] @@ -45,12 +47,12 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f rerank_document = Document( page_content=result.text, metadata={ - "doc_id": documents[result.index].metadata['doc_id'], - "doc_hash": documents[result.index].metadata['doc_hash'], - "document_id": documents[result.index].metadata['document_id'], - "dataset_id": documents[result.index].metadata['dataset_id'], - 'score': result.score - } + "doc_id": documents[result.index].metadata["doc_id"], + "doc_hash": documents[result.index].metadata["doc_hash"], + "document_id": documents[result.index].metadata["document_id"], + "dataset_id": documents[result.index].metadata["dataset_id"], + "score": result.score, + }, ) rerank_documents.append(rerank_document) diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d8a78739826a31..4375079ee52775 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -13,13 +13,18 @@ class WeightRerankRunner: - def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -34,8 +39,8 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -47,13 +52,15 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): # format document - score = self.weights.vector_setting.vector_weight * query_vector_score + \ - self.weights.keyword_setting.keyword_weight * query_score + score = ( + self.weights.vector_setting.vector_weight * query_vector_score + + self.weights.keyword_setting.keyword_weight * query_score + ) if score_threshold and score < score_threshold: continue - document.metadata['score'] = score + document.metadata["score"] = score rerank_documents.append(document) - rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -70,7 +77,7 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -132,8 +139,9 @@ def cosine_similarity(vec1, vec2): return similarities - def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], - vector_setting: VectorSetting) -> list[float]: + def _calculate_cosine( + self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting + ) -> list[float]: """ Calculate Cosine scores :param query: search query @@ -149,15 +157,14 @@ def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document tenant_id=tenant_id, provider=vector_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=vector_setting.embedding_model_name - + model=vector_setting.embedding_model_name, ) cache_embedding = CacheEmbedding(embedding_model) query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if 'score' in document.metadata: - query_vector_scores.append(document.metadata['score']) + if "score" in document.metadata: + query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy vec1 = np.array(query_vector) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index db01652f896b05..4948ec6ba87c20 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -32,14 +32,11 @@ from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -48,15 +45,18 @@ def __init__(self, application_generate_entity=None): self.application_generate_entity = application_generate_entity def retrieve( - self, app_id: str, user_id: str, tenant_id: str, - model_config: ModelConfigWithCredentialsEntity, - config: DatasetEntity, - query: str, - invoke_from: InvokeFrom, - show_retrieve_source: bool, - hit_callback: DatasetIndexToolCallbackHandler, - message_id: str, - memory: Optional[TokenBufferMemory] = None, + self, + app_id: str, + user_id: str, + tenant_id: str, + model_config: ModelConfigWithCredentialsEntity, + config: DatasetEntity, + query: str, + invoke_from: InvokeFrom, + show_retrieve_source: bool, + hit_callback: DatasetIndexToolCallbackHandler, + message_id: str, + memory: Optional[TokenBufferMemory] = None, ) -> Optional[str]: """ Retrieve dataset. @@ -84,16 +84,12 @@ def retrieve( model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.model + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if not model_schema: @@ -102,39 +98,46 @@ def retrieve( planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) all_documents = [] - user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' + user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( - app_id, tenant_id, user_id, user_from, available_datasets, query, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, model_instance, - model_config, planning_strategy, message_id + model_config, + planning_strategy, + message_id, ) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: all_documents = self.multiple_retrieve( - app_id, tenant_id, user_id, user_from, - available_datasets, query, retrieve_config.top_k, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, + retrieve_config.top_k, retrieve_config.score_threshold, retrieve_config.rerank_mode, retrieve_config.reranking_model, @@ -145,89 +148,89 @@ def retrieve( document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if show_retrieve_source: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ).first() - document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': invoke_from.to_source(), - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, None), } - if invoke_from.to_source() == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 if hit_callback: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) - return '' + return "" def single_retrieve( - self, app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - model_instance: ModelInstance, - model_config: ModelConfigWithCredentialsEntity, - planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + message_id: Optional[str] = None, ): tools = [] for dataset in available_datasets: description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") message_tool = PromptMessageTool( name=dataset.id, description=description, @@ -235,14 +238,15 @@ def single_retrieve( "type": "object", "properties": {}, "required": [], - } + }, ) tools.append(message_tool) dataset_id = None if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, - user_id, tenant_id) + dataset_id = react_multi_dataset_router.invoke( + query, tools, model_config, model_instance, user_id, tenant_id + ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() @@ -250,37 +254,37 @@ def single_retrieve( if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get retrieval method if dataset.indexing_technique == "economy": - retrieval_method = 'keyword_search' + retrieval_method = "keyword_search" else: - retrieval_method = retrieval_model_config['search_method'] + retrieval_method = retrieval_model_config["search_method"] # get reranking model - reranking_model = retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None + reranking_model = ( + retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + ) # get score threshold - score_threshold = .0 + score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") with measure_time() as timer: results = RetrievalService.retrieve( - retrieval_method=retrieval_method, dataset_id=dataset.id, + retrieval_method=retrieval_method, + dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, + top_k=top_k, + score_threshold=score_threshold, reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), - weights=retrieval_model_config.get('weights', None), + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -291,20 +295,20 @@ def single_retrieve( return [] def multiple_retrieve( - self, - app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - top_k: int, - score_threshold: float, - reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - reranking_enable: bool = True, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: float, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, + message_id: Optional[str] = None, ): threads = [] all_documents = [] @@ -312,13 +316,16 @@ def multiple_retrieve( index_type = None for dataset in available_datasets: index_type = dataset.indexing_technique - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -327,16 +334,10 @@ def multiple_retrieve( with measure_time() as timer: if reranking_enable: # do rerank for searched documents - data_post_processor = DataPostProcessor( - tenant_id, reranking_mode, - reranking_model, weights, False - ) + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) else: if index_type == "economy": @@ -357,30 +358,26 @@ def _on_retrieval_end( """Handle retrieval end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None + trace_manager: TraceQueueManager = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, - message_id=message_id, - documents=documents, - timer=timer + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) ) @@ -395,10 +392,10 @@ def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=app_id, created_by_role=user_from, - created_by=user_id + created_by=user_id, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -407,9 +404,7 @@ def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] @@ -419,38 +414,42 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + ) if documents: all_documents.extend(documents) else: if top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) all_documents.extend(documents) - def to_dataset_retriever_tool(self, tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[DatasetRetrieverBaseTool]]: + def to_dataset_retriever_tool( + self, + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -464,18 +463,14 @@ def to_dataset_retriever_tool(self, tenant_id: str, available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) @@ -483,22 +478,18 @@ def to_dataset_retriever_tool(self, tenant_id: str, if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get score threshold score_threshold = None @@ -512,7 +503,7 @@ def to_dataset_retriever_tool(self, tenant_id: str, score_threshold=score_threshold, hit_callbacks=[hit_callback], return_resource=return_resource, - retriever_from=invoke_from.to_source() + retriever_from=invoke_from.to_source(), ) tools.append(tool) @@ -525,8 +516,8 @@ def to_dataset_retriever_tool(self, tenant_id: str, hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), ) tools.append(tool) @@ -547,7 +538,7 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -606,21 +597,19 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata['score'] = score - documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents[:top_k] if top_k else documents - def calculate_vector_score(self, all_documents: list[Document], - top_k: int, score_threshold: float) -> list[Document]: + def calculate_vector_score( + self, all_documents: list[Document], top_k: int, score_threshold: float + ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata["score"] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) return filter_documents[:top_k] if top_k else filter_documents - - - diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index 60770bd4c6e06a..7fc78bce8357da 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -16,9 +16,7 @@ def parse(self, text: str) -> Union[ReactAction, ReactFinish]: if response["action"] == "Final Answer": return ReactFinish({"output": response["action_input"]}, text) else: - return ReactAction( - response["action"], response.get("action_input", {}), text - ) + return ReactAction(response["action"], response.get("action_input", {}), text) else: return ReactFinish({"output": text}, text) except Exception as e: diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index 12aa28a51c1d98..eaa00bca884a7c 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -2,9 +2,9 @@ class RetrievalMethod(Enum): - SEMANTIC_SEARCH = 'semantic_search' - FULL_TEXT_SEARCH = 'full_text_search' - HYBRID_SEARCH = 'hybrid_search' + SEMANTIC_SEARCH = "semantic_search" + FULL_TEXT_SEARCH = "full_text_search" + HYBRID_SEARCH = "hybrid_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 84e53952acbf12..06147fe7b56544 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -6,14 +6,12 @@ class FunctionCallMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -26,22 +24,18 @@ def invoke( try: prompt_messages = [ - SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=query) + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage(content=query), ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, tools=dataset_tools, stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) if result.message.tool_calls: # get retrieval model config return result.message.tool_calls[0].function.name return None except Exception as e: - return None \ No newline at end of file + return None diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 92f24277c1a3cc..33841cac06adbf 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -50,16 +50,14 @@ class ReactMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - user_id: str, - tenant_id: str - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + user_id: str, + tenant_id: str, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -71,23 +69,28 @@ def invoke( return dataset_tools[0].name try: - return self._react_invoke(query=query, model_config=model_config, - model_instance=model_instance, - tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) + return self._react_invoke( + query=query, + model_config=model_config, + model_instance=model_instance, + tools=dataset_tools, + user_id=user_id, + tenant_id=tenant_id, + ) except Exception as e: return None def _react_invoke( - self, - query: str, - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - tools: Sequence[PromptMessageTool], - user_id: str, - tenant_id: str, - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + tools: Sequence[PromptMessageTool], + user_id: str, + tenant_id: str, + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -103,18 +106,18 @@ def _react_invoke( prefix=prefix, format_instructions=format_instructions, ) - stop = ['Observation:'] + stop = ["Observation:"] # handle invoke result prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=None, memory=None, - model_config=model_config + model_config=model_config, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -122,7 +125,7 @@ def _react_invoke( prompt_messages=prompt_messages, stop=stop, user_id=user_id, - tenant_id=tenant_id + tenant_id=tenant_id, ) output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) @@ -130,17 +133,21 @@ def _react_invoke( return react_decision.tool return None - def _invoke_llm(self, completion_param: dict, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str - ) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + completion_param: dict, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str], + user_id: str, + tenant_id: str, + ) -> tuple[str, LLMUsage]: """ - Invoke large language model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: + Invoke large language model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -151,9 +158,7 @@ def _invoke_llm(self, completion_param: dict, ) # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) + text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) @@ -168,7 +173,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage """ model = None prompt_messages = [] - full_text = '' + full_text = "" usage = None for result in invoke_result: text = result.delta.message.content @@ -189,40 +194,35 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage return full_text, usage def create_chat_prompt( - self, - query: str, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: tool_strings.append( - f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}" + ) formatted_tools = "\n".join(tool_strings) unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) prompt_messages = [] - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=template - ) + system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template) prompt_messages.append(system_prompt_messages) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=query - ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query) prompt_messages.append(user_prompt_message) return prompt_messages def create_completion_prompt( - self, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0c1cb57c7f4e03..53032b34d570c7 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,4 +1,5 @@ """Functionality for splitting text.""" + from __future__ import annotations from typing import Any, Optional @@ -18,31 +19,29 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ - This class is used to implement from_gpt2_encoder, to prevent using of tiktoken + This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( - cls: type[TS], - embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + embedding_model_instance: Optional[ModelInstance], + allowed_special: Union[Literal[all], Set[str]] = set(), + disallowed_special: Union[Literal[all], Collection[str]] = "all", + **kwargs: Any, ): def _token_encoder(text: str) -> int: if not text: return 0 if embedding_model_instance: - return embedding_model_instance.get_text_embedding_num_tokens( - texts=[text] - ) + return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) else: return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', + "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "allowed_special": allowed_special, "disallowed_special": disallowed_special, } diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index f06f22a00e1855..97d0721304405a 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -22,9 +22,7 @@ TS = TypeVar("TS", bound="TextSplitter") -def _split_text_with_regex( - text: str, separator: str, keep_separator: bool -) -> list[str]: +def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: @@ -37,19 +35,19 @@ def _split_text_with_regex( splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != '\n')] + return [s for s in splits if (s != "" and s != "\n")] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( - self, - chunk_size: int = 4000, - chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, - keep_separator: bool = False, - add_start_index: bool = False, + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, ) -> None: """Create a new TextSplitter. @@ -62,8 +60,7 @@ def __init__( """ if chunk_overlap > chunk_size: raise ValueError( - f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " - f"({chunk_size}), should be smaller." + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap @@ -75,9 +72,7 @@ def __init__( def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents( - self, texts: list[str], metadatas: Optional[list[dict]] = None - ) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -119,14 +114,10 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int index = 0 for d in splits: _len = lengths[index] - if ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - ): + if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, " - f"which is longer than the specified {self._chunk_size}" + f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) @@ -136,13 +127,9 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): - total -= self._length_function(current_doc[0]) + ( - separator_len if len(current_doc) > 1 else 0 - ) + total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) @@ -159,28 +146,25 @@ def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitt from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - "Tokenizer received was not an instance of PreTrainedTokenizerBase" - ) + raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( - "Could not import transformers python package. " - "Please install it with `pip install transformers`." + "Could not import transformers python package. " "Please install it with `pip install transformers`." ) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -217,15 +201,11 @@ def _tiktoken_encoder(text: str) -> int: return cls(length_function=_tiktoken_encoder, **kwargs) - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -267,9 +247,7 @@ class HeaderType(TypedDict): class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" - def __init__( - self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False - ): + def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): """Create a new MarkdownHeaderTextSplitter. Args: @@ -280,9 +258,7 @@ def __init__( self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted( - headers_to_split_on, key=lambda split: len(split[0]), reverse=True - ) + self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: """Combine lines with common metadata into chunks @@ -292,10 +268,7 @@ def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: aggregated_chunks: list[LineType] = [] for line in lines: - if ( - aggregated_chunks - and aggregated_chunks[-1]["metadata"] == line["metadata"] - ): + if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content @@ -304,10 +277,7 @@ def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in aggregated_chunks - ] + return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] def split_text(self, text: str) -> list[Document]: """Split markdown file @@ -332,10 +302,9 @@ def split_text(self, text: str) -> list[Document]: for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: @@ -343,10 +312,7 @@ def split_text(self, text: str) -> list[Document]: current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack - while ( - header_stack - and header_stack[-1]["level"] >= current_header_level - ): + while header_stack and header_stack[-1]["level"] >= current_header_level: # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() @@ -359,7 +325,7 @@ def split_text(self, text: str) -> list[Document]: header: HeaderType = { "level": current_header_level, "name": name, - "data": stripped_line[len(sep):].strip(), + "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header @@ -392,9 +358,7 @@ def split_text(self, text: str) -> list[Document]: current_metadata = initial_metadata.copy() if current_content: - lines_with_metadata.append( - {"content": "\n".join(current_content), "metadata": current_metadata} - ) + lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata @@ -402,8 +366,7 @@ def split_text(self, text: str) -> list[Document]: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in lines_with_metadata + Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata ] @@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( - self, - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) @@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter): """ def __init__( - self, - separators: Optional[list[str]] = None, - keep_separator: bool = True, - **kwargs: Any, + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) @@ -508,7 +471,7 @@ def _split_text(self, text: str, separators: list[str]) -> list[str]: break if re.search(_s, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2b01b8fd8e89c9..b988a588e907df 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -10,23 +10,23 @@ class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None labels: list[str] = None -UserToolProviderTypeLiteral = Optional[Literal[ - 'builtin', 'api', 'workflow' -]] + +UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] + class UserToolProvider(BaseModel): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str - label: I18nObject # label + label: I18nObject # label type: ToolProviderType masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None @@ -40,26 +40,27 @@ def to_dict(self) -> dict: # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) for tool in tools: - if tool.get('parameters'): - for parameter in tool.get('parameters'): - if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: - parameter['type'] = 'files' + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + parameter["type"] = "files" # ------------- return { - 'id': self.id, - 'author': self.author, - 'name': self.name, - 'description': self.description.to_dict(), - 'icon': self.icon, - 'label': self.label.to_dict(), - 'type': self.type.value, - 'team_credentials': self.masked_credentials, - 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete, - 'tools': tools, - 'labels': self.labels, + "id": self.id, + "author": self.author, + "name": self.name, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, } + class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] \ No newline at end of file + credentials: dict[str, ToolProviderCredentials] diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 55e31e8c35e5b7..37a926697bb432 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None pt_BR: Optional[str] = None en_US: str @@ -19,8 +20,4 @@ def __init__(self, **data): self.pt_BR = self.en_US def to_dict(self) -> dict: - return { - 'zh_Hans': self.zh_Hans, - 'en_US': self.en_US, - 'pt_BR': self.pt_BR - } + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index d18d27fb02beee..da6201c5aaecd2 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -9,6 +9,7 @@ class ApiToolBundle(BaseModel): """ This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. """ + # server_url server_url: str # method diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index e31dec55d21898..02b8b35be7d628 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,27 +7,29 @@ class ToolLabelEnum(Enum): - SEARCH = 'search' - IMAGE = 'image' - VIDEOS = 'videos' - WEATHER = 'weather' - FINANCE = 'finance' - DESIGN = 'design' - TRAVEL = 'travel' - SOCIAL = 'social' - NEWS = 'news' - MEDICAL = 'medical' - PRODUCTIVITY = 'productivity' - EDUCATION = 'education' - BUSINESS = 'business' - ENTERTAINMENT = 'entertainment' - UTILITIES = 'utilities' - OTHER = 'other' + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + class ToolProviderType(Enum): """ - Enum class for tool provider + Enum class for tool provider """ + BUILT_IN = "builtin" WORKFLOW = "workflow" API = "api" @@ -35,7 +37,7 @@ class ToolProviderType(Enum): DATASET_RETRIEVAL = "dataset-retrieval" @classmethod - def value_of(cls, value: str) -> 'ToolProviderType': + def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. @@ -45,19 +47,21 @@ def value_of(cls, value: str) -> 'ToolProviderType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. """ + OPENAPI = "openapi" SWAGGER = "swagger" OPENAI_PLUGIN = "openai_plugin" OPENAI_ACTIONS = "openai_actions" @classmethod - def value_of(cls, value: str) -> 'ApiProviderSchemaType': + def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. @@ -67,17 +71,19 @@ def value_of(cls, value: str) -> 'ApiProviderSchemaType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. """ + NONE = "none" API_KEY = "api_key" @classmethod - def value_of(cls, value: str) -> 'ApiProviderAuthType': + def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. @@ -87,7 +93,8 @@ def value_of(cls, value: str) -> 'ApiProviderAuthType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ToolInvokeMessage(BaseModel): class MessageType(Enum): @@ -105,19 +112,21 @@ class MessageType(Enum): """ message: str | bytes | dict | None = None meta: dict[str, Any] | None = None - save_as: str = '' + save_as: str = "" + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = '' + save_as: str = "" file_var: Optional[dict[str, Any]] = None + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - @field_validator('value', mode='before') + @field_validator("value", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -136,9 +145,9 @@ class ToolParameterType(str, Enum): FILE = "file" class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") @@ -154,25 +163,32 @@ class ToolParameterForm(Enum): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, - required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': + def get_simple_instance( + cls, + name: str, + llm_description: str, + type: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": """ - get a simple tool parameter + get a simple tool parameter - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter """ # convert options to ToolParameterOption if options: - options = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + options = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] return cls( name=name, - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, @@ -180,18 +196,24 @@ def get_simple_instance(cls, options=options, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -199,10 +221,12 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None + class ToolCredentialsOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolProviderCredentials(BaseModel): class CredentialsType(Enum): SECRET_INPUT = "secret-input" @@ -221,7 +245,7 @@ def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") @staticmethod def default(value: str) -> str: @@ -239,33 +263,38 @@ def default(value: str) -> str: def to_dict(self) -> dict: return { - 'name': self.name, - 'type': self.type.value, - 'required': self.required, - 'default': self.default, - 'options': self.options, - 'help': self.help.to_dict() if self.help else None, - 'label': self.label.to_dict(), - 'url': self.url, - 'placeholder': self.placeholder.to_dict() if self.placeholder else None, + "name": self.name, + "type": self.type.value, + "required": self.required, + "default": self.default, + "options": self.options, + "help": self.help.to_dict() if self.help else None, + "label": self.label.to_dict(), + "url": self.url, + "placeholder": self.placeholder.to_dict() if self.placeholder else None, } + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -274,26 +303,26 @@ class ToolRuntimeVariablePool(BaseModel): pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") def __init__(self, **data: Any): - pool = data.get('pool', []) + pool = data.get("pool", []) # convert pool into correct type for index, variable in enumerate(pool): - if variable['type'] == ToolRuntimeVariableType.TEXT.value: + if variable["type"] == ToolRuntimeVariableType.TEXT.value: pool[index] = ToolRuntimeTextVariable(**variable) - elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) def dict(self) -> dict: return { - 'conversation_id': self.conversation_id, - 'user_id': self.user_id, - 'tenant_id': self.tenant_id, - 'pool': [variable.model_dump() for variable in self.pool], + "conversation_id": self.conversation_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "pool": [variable.model_dump() for variable in self.pool], } def set_text(self, tool_name: str, name: str, value: str) -> None: """ - set a text variable + set a text variable """ for variable in self.pool: if variable.name == name: @@ -314,10 +343,10 @@ def set_text(self, tool_name: str, name: str, value: str) -> None: def set_file(self, tool_name: str, value: str, name: str = None) -> None: """ - set an image variable + set an image variable - :param tool_name: the name of the tool - :param value: the id of the file + :param tool_name: the name of the tool + :param value: the id of the file """ # check how many image variables are there image_variable_count = 0 @@ -345,22 +374,27 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None: self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration """ + type: str = Field(..., description="The type of the model tool") model: str = Field(..., description="The model") label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ + provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") @@ -370,27 +404,30 @@ class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration """ + name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta """ + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @classmethod - def empty(cls) -> 'ToolInvokeMeta': + def empty(cls) -> "ToolInvokeMeta": """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> 'ToolInvokeMeta': + def error_instance(cls, error: str) -> "ToolInvokeMeta": """ Get an instance of ToolInvokeMeta with error """ @@ -398,22 +435,26 @@ def error_instance(cls, error: str) -> 'ToolInvokeMeta': def to_dict(self) -> dict: return { - 'time_cost': self.time_cost, - 'error': self.error, - 'tool_config': self.tool_config, + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, } + class ToolLabel(BaseModel): """ Tool label """ + name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") icon: str = Field(..., description="The icon of the tool") + class ToolInvokeFrom(Enum): """ Enum class for tool invoke """ + WORKFLOW = "workflow" AGENT = "agent" diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index d0be5e93557fe6..f9db190f91e682 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -2,73 +2,109 @@ from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum ICONS = { - ToolLabelEnum.SEARCH: ''' + ToolLabelEnum.SEARCH: """ -''', - ToolLabelEnum.IMAGE: ''' +""", + ToolLabelEnum.IMAGE: """ -''', - ToolLabelEnum.VIDEOS: ''' +""", + ToolLabelEnum.VIDEOS: """ -''', - ToolLabelEnum.WEATHER: ''' +""", + ToolLabelEnum.WEATHER: """ -''', - ToolLabelEnum.FINANCE: ''' +""", + ToolLabelEnum.FINANCE: """ -''', - ToolLabelEnum.DESIGN: ''' +""", + ToolLabelEnum.DESIGN: """ -''', - ToolLabelEnum.TRAVEL: ''' +""", + ToolLabelEnum.TRAVEL: """ -''', - ToolLabelEnum.SOCIAL: ''' +""", + ToolLabelEnum.SOCIAL: """ -''', - ToolLabelEnum.NEWS: ''' +""", + ToolLabelEnum.NEWS: """ -''', - ToolLabelEnum.MEDICAL: ''' +""", + ToolLabelEnum.MEDICAL: """ -''', - ToolLabelEnum.PRODUCTIVITY: ''' +""", + ToolLabelEnum.PRODUCTIVITY: """ -''', - ToolLabelEnum.EDUCATION: ''' +""", + ToolLabelEnum.EDUCATION: """ -''', - ToolLabelEnum.BUSINESS: ''' +""", + ToolLabelEnum.BUSINESS: """ -''', - ToolLabelEnum.ENTERTAINMENT: ''' +""", + ToolLabelEnum.ENTERTAINMENT: """ -''', - ToolLabelEnum.UTILITIES: ''' +""", + ToolLabelEnum.UTILITIES: """ -''', - ToolLabelEnum.OTHER: ''' +""", + ToolLabelEnum.OTHER: """ -''' +""", } default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), - ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), - ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), - ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), - ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), - ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), - ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), - ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), - ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), - ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), - ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), - ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), - ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), - ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), - ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), - ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), } default_tool_labels = [v for k, v in default_tool_label_dict.items()] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 9fd8322db13741..6febf137b000f9 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -4,23 +4,30 @@ class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): pass + class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta \ No newline at end of file + meta: ToolInvokeMeta diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index ae80ad2114cce0..2e6018cffc1c4a 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,4 +1,3 @@ - from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -18,85 +17,69 @@ class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { - 'auth_type': ToolProviderCredentials( - name='auth_type', + "auth_type": ToolProviderCredentials( + name="auth_type", required=True, type=ToolProviderCredentials.CredentialsType.SELECT, options=[ - ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), ], - default='none', - help=I18nObject( - en_US='The auth type of the api provider', - zh_Hans='api provider 的认证类型' - ) + default="none", + help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) } if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ToolProviderCredentials( - name='api_key_header', + "api_key_header": ToolProviderCredentials( + name="api_key_header", required=False, - default='api_key', + default="api_key", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, - help=I18nObject( - en_US='The header name of the api key', - zh_Hans='携带 api key 的 header 名称' - ) + help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - 'api_key_value': ToolProviderCredentials( - name='api_key_value', + "api_key_value": ToolProviderCredentials( + name="api_key_value", required=True, type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, - help=I18nObject( - en_US='The api key', - zh_Hans='api key的值' - ) + help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - 'api_key_header_prefix': ToolProviderCredentials( - name='api_key_header_prefix', + "api_key_header_prefix": ToolProviderCredentials( + name="api_key_header_prefix", required=False, - default='basic', + default="basic", type=ToolProviderCredentials.CredentialsType.SELECT, - help=I18nObject( - en_US='The prefix of the api key header', - zh_Hans='api key header 的前缀' - ), + help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"), options=[ - ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) - ] - ) + ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")), + ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")), + ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), + ], + ), } elif auth_type == ApiProviderAuthType.NONE: pass else: - raise ValueError(f'invalid auth type {auth_type}') - - user_name = db_provider.user.name if db_provider.user_id else '' - - return ApiToolProviderController(**{ - 'identity': { - 'author': user_name, - 'name': db_provider.name, - 'label': { - 'en_US': db_provider.name, - 'zh_Hans': db_provider.name - }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description + raise ValueError(f"invalid auth type {auth_type}") + + user_name = db_provider.user.name if db_provider.user_id else "" + + return ApiToolProviderController( + **{ + "identity": { + "author": user_name, + "name": db_provider.name, + "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'icon': db_provider.icon, - }, - 'credentials_schema': credentials_schema, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": credentials_schema, + "provider_id": db_provider.id or "", + } + ) @property def provider_type(self) -> ToolProviderType: @@ -104,39 +87,35 @@ def provider_type(self) -> ToolProviderType: def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ - parse tool bundle to tool + parse tool bundle to tool - :param tool_bundle: the tool bundle - :return: the tool + :param tool_bundle: the tool bundle + :return: the tool """ - return ApiTool(**{ - 'api_bundle': tool_bundle, - 'identity' : { - 'author': tool_bundle.author, - 'name': tool_bundle.operation_id, - 'label': { - 'en_US': tool_bundle.operation_id, - 'zh_Hans': tool_bundle.operation_id + return ApiTool( + **{ + "api_bundle": tool_bundle, + "identity": { + "author": tool_bundle.author, + "name": tool_bundle.operation_id, + "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, + "icon": self.identity.icon, + "provider": self.provider_id, }, - 'icon': self.identity.icon, - 'provider': self.provider_id, - }, - 'description': { - 'human': { - 'en_US': tool_bundle.summary or '', - 'zh_Hans': tool_bundle.summary or '' + "description": { + "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, + "llm": tool_bundle.summary or "", }, - 'llm': tool_bundle.summary or '' - }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], - }) + "parameters": tool_bundle.parameters if tool_bundle.parameters else [], + } + ) def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ - load bundled tools + load bundled tools - :param tools: the bundled tools - :return: the tools + :param tools: the bundled tools + :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] @@ -144,22 +123,23 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - + tools: list[Tool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == self.identity.name - ).all() + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .all() + ) if db_providers and len(db_providers) != 0: for db_provider in db_providers: @@ -167,16 +147,16 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: assistant_tool = self._parse_tool_bundle(tool) assistant_tool.is_team_authorization = True tools.append(assistant_tool) - + self.tools = tools return tools - + def get_tool(self, tool_name: str) -> ApiTool: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: self.get_tools() @@ -185,4 +165,4 @@ def get_tool(self, tool_name: str) -> ApiTool: if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 2d472e0a93c866..01544d7e562d52 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,11 +11,12 @@ logger = logging.getLogger(__name__) + class AppToolProviderEntity(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.APP - + def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass @@ -23,9 +24,13 @@ def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) - pass def get_tools(self, user_id: str) -> list[Tool]: - db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter( - PublishedAppTool.user_id == user_id, - ).all() + db_tools: list[PublishedAppTool] = ( + db.session.query(PublishedAppTool) + .filter( + PublishedAppTool.user_id == user_id, + ) + .all() + ) if not db_tools or len(db_tools) == 0: return [] @@ -34,23 +39,17 @@ def get_tools(self, user_id: str) -> list[Tool]: for db_tool in db_tools: tool = { - 'identity': { - 'author': db_tool.author, - 'name': db_tool.tool_name, - 'label': { - 'en_US': db_tool.tool_name, - 'zh_Hans': db_tool.tool_name - }, - 'icon': '' + "identity": { + "author": db_tool.author, + "name": db_tool.tool_name, + "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name}, + "icon": "", }, - 'description': { - 'human': { - 'en_US': db_tool.description_i18n.en_US, - 'zh_Hans': db_tool.description_i18n.zh_Hans - }, - 'llm': db_tool.llm_description + "description": { + "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans}, + "llm": db_tool.llm_description, }, - 'parameters': [] + "parameters": [], } # get app from db app: App = db_tool.app @@ -64,52 +63,41 @@ def get_tools(self, user_id: str) -> list[Tool]: for input_form in user_input_form_list: # get type form_type = input_form.keys()[0] - default = input_form[form_type]['default'] - required = input_form[form_type]['required'] - label = input_form[form_type]['label'] - variable_name = input_form[form_type]['variable_name'] - options = input_form[form_type].get('options', []) - if form_type == 'paragraph' or form_type == 'text-input': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.STRING, - required=required, - default=default - )) - elif form_type == 'select': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.SELECT, - required=required, - default=default, - options=[ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in options] - )) + default = input_form[form_type]["default"] + required = input_form[form_type]["required"] + label = input_form[form_type]["label"] + variable_name = input_form[form_type]["variable_name"] + options = input_form[form_type].get("options", []) + if form_type == "paragraph" or form_type == "text-input": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.STRING, + required=required, + default=default, + ) + ) + elif form_type == "select": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.SELECT, + required=required, + default=default, + options=[ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ], + ) + ) tools.append(Tool(**tool)) - return tools \ No newline at end of file + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 062668fc5bf8bf..5c10f72fdaed01 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -10,7 +10,7 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) def name_func(provider: UserToolProvider) -> str: return provider.name diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df4ff3..e0cbbd2992a515 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -6,6 +6,6 @@ class AIPPTProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') + AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__") except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 8d6883a3b114ac..7cee8f9f799b83 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool): A tool for generating a ppt """ - _api_base_url = URL('https://co.aippt.cn/api') + _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock:Optional[Lock] = None + _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock:Optional[Lock] = None + _style_cache_lock: Optional[Lock] = None _task = {} _task_type_map = { - 'auto': 1, - 'markdown': 7, + "auto": 1, + "markdown": 7, } def __init__(self, **kwargs: Any): @@ -48,65 +48,55 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - title = tool_parameters.get('title', '') + title = tool_parameters.get("title", "") if not title: - return self.create_text_message('Please provide a title for the ppt') - - model = tool_parameters.get('model', 'aippt') + return self.create_text_message("Please provide a title for the ppt") + + model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message('Please provide a model for the ppt') - - outline = tool_parameters.get('outline', '') + return self.create_text_message("Please provide a model for the ppt") + + outline = tool_parameters.get("outline", "") # create task task_id = self._create_task( - type=self._task_type_map['auto' if not outline else 'markdown'], + type=self._task_type_map["auto" if not outline else "markdown"], title=title, content=outline, - user_id=user_id + user_id=user_id, ) # get suit - color = tool_parameters.get('color') - style = tool_parameters.get('style') + color = tool_parameters.get("color") + style = tool_parameters.get("style") - if color == '__default__': - color_id = '' + if color == "__default__": + color_id = "" else: - color_id = int(color.split('-')[1]) + color_id = int(color.split("-")[1]) - if style == '__default__': - style_id = '' + if style == "__default__": + style_id = "" else: - style_id = int(style.split('-')[1]) + style_id = int(style.split("-")[1]) suit_id = self._get_suit(style_id=style_id, colour_id=color_id) # generate outline if not outline: - self._generate_outline( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_outline(task_id=task_id, model=model, user_id=user_id) # generate content - self._generate_content( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_content(task_id=task_id, model=model, user_id=user_id) # generate ppt - _, ppt_url = self._generate_ppt( - task_id=task_id, - suit_id=suit_id, - user_id=user_id - ) + _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message('''the ppt has been created successfully,''' - f'''the ppt url is {ppt_url}''' - '''please give the ppt url to user and direct user to download it.''') + return self.create_text_message( + """the ppt has been created successfully,""" + f"""the ppt url is {ppt_url}""" + """please give the ppt url to user and direct user to download it.""" + ) def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: """ @@ -119,129 +109,121 @@ def _create_task(self, type: int, title: str, content: str, user_id: str) -> str :return: the task ID """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + str(self._api_base_url / "ai" / "chat" / "v2" / "task"), headers=headers, - files={ - 'type': ('', str(type)), - 'title': ('', title), - 'content': ('', content) - } + files={"type": ("", str(type)), "title": ("", title), "content": ("", content)}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to create task: {response.get("msg")}') - return response.get('data', {}).get('id') - + return response.get("data", {}).get("id") + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "outline" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "outline" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - outline = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + outline = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - outline += data.get('content', '') + outline += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate outline: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate outline: {data}") + return outline - + def _generate_content(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "content" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "content" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - if model == 'aippt': - content = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + if model == "aippt": + content = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - content += data.get('content', '') + content += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate content: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate content: {data}") + return content - elif model == 'wenxin': + elif model == "wenxin": response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate content: {response.get("msg")}') - - return response.get('data', '') - - return '' + + return response.get("data", "") + + return "" def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: """ @@ -252,83 +234,73 @@ def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: :return: the cover url of the ppt and the ppt url """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'design' / 'v2' / 'save'), + str(self._api_base_url / "design" / "v2" / "save"), headers=headers, - data={ - 'task_id': task_id, - 'template_id': suit_id - } + data={"task_id": task_id, "template_id": suit_id}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - id = response.get('data', {}).get('id') - cover_url = response.get('data', {}).get('cover_url') + + id = response.get("data", {}).get("id") + cover_url = response.get("data", {}).get("cover_url") response = post( - str(self._api_base_url / 'download' / 'export' / 'file'), + str(self._api_base_url / "download" / "export" / "file"), headers=headers, - data={ - 'id': id, - 'format': 'ppt', - 'files_to_zip': False, - 'edit': True - } + data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - export_code = response.get('data') + + export_code = response.get("data") if not export_code: - raise Exception('Failed to generate ppt, the export code is empty') - + raise Exception("Failed to generate ppt, the export code is empty") + current_iteration = 0 while current_iteration < 50: # get ppt url response = post( - str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), + str(self._api_base_url / "download" / "export" / "file" / "result"), headers=headers, - data={ - 'task_key': export_code - } + data={"task_key": export_code}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - if response.get('msg') == '导出中': + + if response.get("msg") == "导出中": current_iteration += 1 sleep(2) continue - - ppt_url = response.get('data', []) + + ppt_url = response.get("data", []) if len(ppt_url) == 0: - raise Exception('Failed to generate ppt, the ppt url is empty') - + raise Exception("Failed to generate ppt, the ppt url is empty") + return cover_url, ppt_url[0] - - raise Exception('Failed to generate ppt, the export is timeout') - + + raise Exception("Failed to generate ppt, the export is timeout") + @classmethod def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: """ @@ -337,53 +309,43 @@ def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: :param credentials: the credentials :return: the API token """ - access_key = credentials['aippt_access_key'] - secret_key = credentials['aippt_secret_key'] + access_key = credentials["aippt_access_key"] + secret_key = credentials["aippt_secret_key"] - cache_key = f'{access_key}#@#{user_id}' + cache_key = f"{access_key}#@#{user_id}" with cls._api_token_cache_lock: # clear expired tokens now = time() for key in list(cls._api_token_cache.keys()): - if cls._api_token_cache[key]['expire'] < now: + if cls._api_token_cache[key]["expire"] < now: del cls._api_token_cache[key] if cache_key in cls._api_token_cache: - return cls._api_token_cache[cache_key]['token'] - + return cls._api_token_cache[cache_key]["token"] + # get token headers = { - 'x-api-key': access_key, - 'x-timestamp': str(int(now)), - 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + "x-api-key": access_key, + "x-timestamp": str(int(now)), + "x-signature": cls._calculate_sign(access_key, secret_key, int(now)), } - param = { - 'uid': user_id, - 'channel': '' - } + param = {"uid": user_id, "channel": ""} - response = get( - str(cls._api_base_url / 'grant' / 'token'), - params=param, - headers=headers - ) + response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') + raise Exception(f"Failed to connect to aippt: {response.text}") response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - token = response.get('data', {}).get('token') - expire = response.get('data', {}).get('time_expire') + + token = response.get("data", {}).get("token") + expire = response.get("data", {}).get("time_expire") with cls._api_token_cache_lock: - cls._api_token_cache[cache_key] = { - 'token': token, - 'expire': now + expire - } + cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire} return token @@ -391,11 +353,9 @@ def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), - msg=f'GET@/api/grant/token/@{timestamp}'.encode(), - digestmod=sha1 + key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 ).digest() - ).decode('utf-8') + ).decode("utf-8") @classmethod def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: @@ -408,47 +368,46 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di # clear expired styles now = time() for key in list(cls._style_cache.keys()): - if cls._style_cache[key]['expire'] < now: + if cls._style_cache[key]["expire"] < now: del cls._style_cache[key] key = f'{credentials["aippt_access_key"]}#@#{user_id}' if key in cls._style_cache: - return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] headers = { - 'x-channel': '', - 'x-api-key': credentials['aippt_access_key'], - 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + "x-channel": "", + "x-api-key": credentials["aippt_access_key"], + "x-token": cls._get_api_token(credentials=credentials, user_id=user_id), } - response = get( - str(cls._api_base_url / 'template_component' / 'suit' / 'select'), - headers=headers - ) + response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - colors = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('name'), - 'en_name': item.get('en_name', item.get('name')), - } for item in response.get('data', {}).get('colour') or []] - styles = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('title'), - } for item in response.get('data', {}).get('suit_style') or []] - with cls._style_cache_lock: - cls._style_cache[key] = { - 'colors': colors, - 'styles': styles, - 'expire': now + 60 * 60 + colors = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("name"), + "en_name": item.get("en_name", item.get("name")), + } + for item in response.get("data", {}).get("colour") or [] + ] + styles = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("title"), } + for item in response.get("data", {}).get("suit_style") or [] + ] + + with cls._style_cache_lock: + cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60} return colors, styles @@ -459,44 +418,39 @@ def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): - raise Exception('Please provide aippt credentials') + if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + raise Exception("Please provide aippt credentials") return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) - + def _get_suit(self, style_id: int, colour_id: int) -> int: """ Get suit """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'search'), + str(self._api_base_url / "template_component" / "suit" / "search"), headers=headers, - params={ - 'style_id': style_id, - 'colour_id': colour_id, - 'page': 1, - 'page_size': 1 - } + params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - if len(response.get('data', {}).get('list') or []) > 0: - return response.get('data', {}).get('list')[0].get('id') - - raise Exception('Failed to get suit, the suit does not exist, please check the style and color') - + + if len(response.get("data", {}).get("list") or []) > 0: + return response.get("data", {}).get("list")[0].get("id") + + raise Exception("Failed to get suit, the suit does not exist, please check the style and color") + def get_runtime_parameters(self) -> list[ToolParameter]: """ Get runtime parameters @@ -504,43 +458,40 @@ def get_runtime_parameters(self) -> list[ToolParameter]: Override this method to add runtime parameters to the tool. """ try: - colors, styles = self.get_styles(user_id='__dify_system__') + colors, styles = self.get_styles(user_id="__dify_system__") except Exception as e: - colors, styles = [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ], [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ] + colors, styles = ( + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + ) return [ ToolParameter( - name='color', - label=I18nObject(zh_Hans='颜色', en_US='Color'), - human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + name="color", + label=I18nObject(zh_Hans="颜色", en_US="Color"), + human_description=I18nObject(zh_Hans="颜色", en_US="Color"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=colors[0]['id'], + default=colors[0]["id"], options=[ ToolParameterOption( - value=color['id'], - label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) - ) for color in colors - ] + value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"]) + ) + for color in colors + ], ), ToolParameter( - name='style', - label=I18nObject(zh_Hans='风格', en_US='Style'), - human_description=I18nObject(zh_Hans='风格', en_US='Style'), + name="style", + label=I18nObject(zh_Hans="风格", en_US="Style"), + human_description=I18nObject(zh_Hans="风格", en_US="Style"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=styles[0]['id'], + default=styles[0]["id"], options=[ - ToolParameterOption( - value=style['id'], - label=I18nObject(zh_Hans=style['name'], en_US=style['name']) - ) for style in styles - ] + ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"])) + for style in styles + ], ), - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py index 01f2acfb5b6310..a84630e5aa990a 100644 --- a/api/core/tools/provider/builtin/alphavantage/alphavantage.py +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -13,7 +13,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "code": "AAPL", # Apple Inc. }, diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py index 5c379b746d5db9..d06611acd05d1d 100644 --- a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -9,17 +9,16 @@ class QueryStockTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - stock_code = tool_parameters.get('code', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + stock_code = tool_parameters.get("code", "") if not stock_code: - return self.create_text_message('Please tell me your stock code') + return self.create_text_message("Please tell me your stock code") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Alpha Vantage API key is required.") params = { @@ -27,7 +26,7 @@ def _invoke(self, "symbol": stock_code, "outputsize": "compact", "datatype": "json", - "apikey": self.runtime.credentials['api_key'] + "apikey": self.runtime.credentials["api_key"], } response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) response.raise_for_status() @@ -35,15 +34,15 @@ def _invoke(self, return self.create_json_message(result) def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: - result = response.get('Time Series (Daily)', {}) + result = response.get("Time Series (Daily)", {}) if not result: return {} stock_result = {} for k, v in result.items(): stock_result[k] = {} - stock_result[k]['open'] = v.get('1. open') - stock_result[k]['high'] = v.get('2. high') - stock_result[k]['low'] = v.get('3. low') - stock_result[k]['close'] = v.get('4. close') - stock_result[k]['volume'] = v.get('5. volume') + stock_result[k]["open"] = v.get("1. open") + stock_result[k]["high"] = v.get("2. high") + stock_result[k]["low"] = v.get("3. low") + stock_result[k]["close"] = v.get("4. close") + stock_result[k]["volume"] = v.get("5. volume") return stock_result diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 707fc69be30cee..ebb2d1a8c47be9 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index ce28373880ba18..98d82c233e0714 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -8,6 +8,8 @@ from core.tools.tool.builtin_tool import BuiltinTool logger = logging.getLogger(__name__) + + class ArxivAPIWrapper(BaseModel): """Wrapper around ArxivAPI. @@ -86,11 +88,13 @@ def run(self, query: str) -> str: class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") - + + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Arxiv search tool with the given user ID and tool parameters. @@ -102,13 +106,13 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + arxiv = ArxivAPIWrapper() - + response = arxiv.run(query) - + return self.create_text_message(self.summary(user_id=user_id, content=response)) diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py index 13ede9601509f5..f81b5dbd27d17c 100644 --- a/api/core/tools/provider/builtin/aws/aws.py +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -11,15 +11,14 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "sagemaker_endpoint" : "", + "sagemaker_endpoint": "", "query": "misaka mikoto", - "candidate_texts" : "hello$$$hello world", - "topk" : 5, - "aws_region" : "" + "candidate_texts": "hello$$$hello world", + "topk": 5, + "aws_region": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 06fcf8a453edba..d6a65b1708281c 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -12,6 +12,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class GuardrailParameters(BaseModel): guardrail_id: str = Field(..., description="The identifier of the guardrail") guardrail_version: str = Field(..., description="The version of the guardrail") @@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel): text: str = Field(..., description="The text to apply the guardrail to") aws_region: str = Field(..., description="AWS region for the Bedrock client") + class ApplyGuardrailTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ try: # Validate and parse input parameters params = GuardrailParameters(**tool_parameters) - + # Initialize AWS client - bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region) # Apply guardrail response = bedrock_client.apply_guardrail( guardrailIdentifier=params.guardrail_id, guardrailVersion=params.guardrail_version, source=params.source, - content=[{"text": {"text": params.text}}] + content=[{"text": {"text": params.text}}], ) - + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") # Check for empty response if not response: return self.create_text_message(text="Received empty response from AWS Bedrock.") - + # Process the result action = response.get("action", "No action specified") outputs = response.get("outputs", []) @@ -58,9 +59,11 @@ def _invoke(self, formatted_assessments = [] for assessment in assessments: for policy_type, policy_data in assessment.items(): - if isinstance(policy_data, dict) and 'topics' in policy_data: - for topic in policy_data['topics']: - formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + if isinstance(policy_data, dict) and "topics" in policy_data: + for topic in policy_data["topics"]: + formatted_assessments.append( + f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}" + ) else: formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") @@ -68,19 +71,19 @@ def _invoke(self, result += f"Output: {output}\n " if formatted_assessments: result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " -# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + # result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" return self.create_text_message(text=result) except BotoCoreError as e: - error_message = f'AWS service error: {str(e)}' + error_message = f"AWS service error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except json.JSONDecodeError as e: - error_message = f'JSON parsing error: {str(e)}' + error_message = f"JSON parsing error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except Exception as e: - error_message = f'An unexpected error occurred: {str(e)}' + error_message = f"An unexpected error occurred: {str(e)}" logger.error(error_message, exc_info=True) - return self.create_text_message(text=error_message) \ No newline at end of file + return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 005ba3deb53311..48755753ace7c1 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): - msg = { - "src_content":text_content, - "src_lang": src_lang, - "dest_lang":dest_lang, + msg = { + "src_content": text_content, + "src_lang": src_lang, + "dest_lang": dest_lang, "dictionary_id": dictionary_name, - "request_type" : request_type, - "model_id" : model_id + "request_type": request_type, + "model_id": model_id, } - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("unicode_escape") return response_str - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") line = 1 - text_content = tool_parameters.get('text_content', '') + text_content = tool_parameters.get("text_content", "") if not text_content: - return self.create_text_message('Please input text_content') - + return self.create_text_message("Please input text_content") + line = 2 - src_lang = tool_parameters.get('src_lang', '') + src_lang = tool_parameters.get("src_lang", "") if not src_lang: - return self.create_text_message('Please input src_lang') - + return self.create_text_message("Please input src_lang") + line = 3 - dest_lang = tool_parameters.get('dest_lang', '') + dest_lang = tool_parameters.get("dest_lang", "") if not dest_lang: - return self.create_text_message('Please input dest_lang') - + return self.create_text_message("Please input dest_lang") + line = 4 - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - + return self.create_text_message("Please input lambda_name") + line = 5 - request_type = tool_parameters.get('request_type', '') + request_type = tool_parameters.get("request_type", "") if not request_type: - return self.create_text_message('Please input request_type') - + return self.create_text_message("Please input request_type") + line = 6 - model_id = tool_parameters.get('model_id', '') + model_id = tool_parameters.get("model_id", "") if not model_id: - return self.create_text_message('Please input model_id') + return self.create_text_message("Please input model_id") line = 7 - dictionary_name = tool_parameters.get('dictionary_name', '') + dictionary_name = tool_parameters.get("dictionary_name", "") if not dictionary_name: - return self.create_text_message('Please input dictionary_name') - - result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + return self.create_text_message("Please input dictionary_name") + + result = self._invoke_lambda( + text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name + ) return self.create_text_message(text=result) except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index bb7f6840b8539f..f43f3b6fe05694 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: - msg = { - "body": yaml_content - } + msg = {"body": yaml_content} logger.info(json.dumps(msg)) - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("utf-8") resp_json = json.loads(response_str) logger.info(resp_json) - if resp_json['statusCode'] != 200: + if resp_json["statusCode"] != 200: raise Exception(f"Invalid status code: {response_str}") - return resp_json['body'] + return resp_json["body"] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region + aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") - yaml_content = tool_parameters.get('yaml_content', '') + yaml_content = tool_parameters.get("yaml_content", "") if not yaml_content: - return self.create_text_message('Please input yaml_content') + return self.create_text_message("Please input yaml_content") - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}') - + return self.create_text_message("Please input lambda_name") + logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}") + result = self._invoke_lambda(lambda_name, yaml_content) logger.debug(result) - + return self.create_text_message(result) except Exception as e: - return self.create_text_message(f'Exception: {str(e)}') + return self.create_text_message(f"Exception: {str(e)}") - console_handler.flush() \ No newline at end of file + console_handler.flush() diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 2b3a3eaad637b0..3c35b65e66bada 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -9,37 +9,33 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - topk:int = None + sagemaker_endpoint: str = None + topk: int = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -47,25 +43,25 @@ def _invoke(self, line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 if not self.topk: - self.topk = tool_parameters.get('topk', 5) + self.topk = tool_parameters.get("topk", 5) line = 3 - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + line = 4 - candidate_texts = tool_parameters.get('candidate_texts') + candidate_texts = tool_parameters.get("candidate_texts") if not candidate_texts: - return self.create_text_message('Please input candidate_texts') - + return self.create_text_message("Please input candidate_texts") + line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content') for item in candidate_docs ] + docs = [item.get("content") for item in candidate_docs] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) @@ -75,10 +71,10 @@ def _invoke(self, candidate_docs[idx]["score"] = scores[idx] line = 8 - sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 9 - return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ] - + return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index a100e6223072f0..bceeaab7453d94 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -14,82 +14,88 @@ class TTSModelType(Enum): CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" + class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - s3_client : Any = None - comprehend_client : Any = None - - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + sagemaker_endpoint: str = None + s3_client: Any = None + comprehend_client: Any = None - response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] - return map_dict.get(language_code, '<|zh|>') + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response["Languages"][0]["LanguageCode"] + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') + self.comprehend_client = boto3.client("comprehend") if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") - tts_text = tool_parameters.get('tts_text') - tts_infer_type = tool_parameters.get('tts_infer_type') + tts_text = tool_parameters.get("tts_text") + tts_infer_type = tool_parameters.get("tts_infer_type") - voice = tool_parameters.get('voice') - mock_voice_audio = tool_parameters.get('mock_voice_audio') - mock_voice_text = tool_parameters.get('mock_voice_text') - voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt') - payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt) + voice = tool_parameters.get("voice") + mock_voice_audio = tool_parameters.get("mock_voice_audio") + mock_voice_text = tool_parameters.get("mock_voice_text") + voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt") + payload = self._build_tts_payload( + tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt + ) result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) - return self.create_text_message(text=result['s3_presign_url']) - + return self.create_text_message(text=result["s3_presign_url"]) + except Exception as e: - return self.create_text_message(f'Exception {str(e)}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 2981a54d3c5716..1fab0d03a28ff3 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "square", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 2ffdd38b72bc22..09f30a59d6c6cd 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -9,47 +9,48 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ client = AzureOpenAI( - api_version=self.runtime.credentials['azure_openai_api_version'], - azure_endpoint=self.runtime.credentials['azure_openai_base_url'], - api_key=self.runtime.credentials['azure_openai_api_key'], + api_version=self.runtime.credentials["azure_openai_api_version"], + azure_endpoint=self.runtime.credentials["azure_openai_base_url"], + api_key=self.runtime.credentials["azure_openai_api_key"], ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} # call openapi dalle3 - model = self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials["azure_openai_api_model_name"] response = client.images.generate( prompt=prompt, model=model, @@ -58,21 +59,25 @@ def _invoke(self, extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) - result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) + result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index f85a5ed4722523..0d9613c0cf64b1 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -8,142 +8,135 @@ class BingSearchTool(BuiltinTool): - url: str = 'https://api.bing.microsoft.com/v7.0/search' - - def _invoke_bing(self, - user_id: str, - server_url: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, - filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url: str = "https://api.bing.microsoft.com/v7.0/search" + + def _invoke_bing( + self, + user_id: str, + server_url: str, + subscription_key: str, + query: str, + limit: int, + result_type: str, + market: str, + lang: str, + filters: list[str], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke bing search + invoke bing search """ - market_code = f'{lang}-{market}' - accept_language = f'{lang},{market_code};q=0.9' - headers = { - 'Ocp-Apim-Subscription-Key': subscription_key, - 'Accept-Language': accept_language - } + market_code = f"{lang}-{market}" + accept_language = f"{lang},{market_code};q=0.9" + headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} query = quote(query) server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - + raise Exception(f"Error {response.status_code}: {response.text}") + response = response.json() - search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] - related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] - entities = response['entities']['value'] if 'entities' in response else [] - news = response['news']['value'] if 'news' in response else [] - computation = response['computation']['value'] if 'computation' in response else None + search_results = response["webPages"]["value"][:limit] if "webPages" in response else [] + related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else [] + entities = response["entities"]["value"] if "entities" in response else [] + news = response["news"]["value"] if "news" in response else [] + computation = response["computation"]["value"] if "computation" in response else None - if result_type == 'link': + if result_type == "link": results = [] if search_results: for result in search_results: url = f': {result["url"]}' if "url" in result else "" - results.append(self.create_text_message( - text=f'{result["name"]}{url}' - )) - + results.append(self.create_text_message(text=f'{result["name"]}{url}')) if entities: for entity in entities: url = f': {entity["url"]}' if "url" in entity else "" - results.append(self.create_text_message( - text=f'{entity.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) if news: for news_item in news: url = f': {news_item["url"]}' if "url" in news_item else "" - results.append(self.create_text_message( - text=f'{news_item.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) if related_searches: for related in related_searches: url = f': {related["displayText"]}' if "displayText" in related else "" - results.append(self.create_text_message( - text=f'{related.get("displayText", "")}{url}' - )) - + results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) + return results else: # construct text - text = '' + text = "" if search_results: for i, result in enumerate(search_results): text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' - if computation and 'expression' in computation and 'value' in computation: - text += '\nComputation:\n' + if computation and "expression" in computation and "value" in computation: + text += "\nComputation:\n" text += f'{computation["expression"]} = {computation["value"]}\n' if entities: - text += '\nEntities:\n' + text += "\nEntities:\n" for entity in entities: url = f'- {entity["url"]}' if "url" in entity else "" text += f'{entity.get("name", "")}{url}\n' if news: - text += '\nNews:\n' + text += "\nNews:\n" for news_item in news: url = f'- {news_item["url"]}' if "url" in news_item else "" text += f'{news_item.get("name", "")}{url}\n' if related_searches: - text += '\n\nRelated Searches:\n' + text += "\n\nRelated Searches:\n" for related in related_searches: url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" text += f'{related.get("displayText", "")}{url}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key') + key = credentials.get("subscription_key") if not key: - raise Exception('subscription_key is required') - - server_url = credentials.get('server_url') + raise Exception("subscription_key is required") + + server_url = credentials.get("server_url") if not server_url: server_url = self.url - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' + raise Exception("query is required") - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if credentials.get('allow_entities', False): - filter.append('Entities') + if credentials.get("allow_entities", False): + filter.append("Entities") - if credentials.get('allow_computation', False): - filter.append('Computation') + if credentials.get("allow_computation", False): + filter.append("Computation") - if credentials.get('allow_news', False): - filter.append('News') + if credentials.get("allow_news", False): + filter.append("News") - if credentials.get('allow_related_searches', False): - filter.append('RelatedSearches') + if credentials.get("allow_related_searches", False): + filter.append("RelatedSearches") - if credentials.get('allow_web_pages', False): - filter.append('WebPages') + if credentials.get("allow_web_pages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + self._invoke_bing( - user_id='test', + user_id="test", server_url=server_url, subscription_key=key, query=query, @@ -151,50 +144,51 @@ def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dic result_type=result_type, market=market, lang=lang, - filters=filter + filters=filter, ) - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - key = self.runtime.credentials.get('subscription_key', None) + key = self.runtime.credentials.get("subscription_key", None) if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) + raise Exception("subscription_key is required") + + server_url = self.runtime.credentials.get("server_url", None) if not server_url: server_url = self.url - - query = tool_parameters.get('query') + + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + raise Exception("query is required") + + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') + if tool_parameters.get("enable_computation", False): + filter.append("Computation") + if tool_parameters.get("enable_entities", False): + filter.append("Entities") + if tool_parameters.get("enable_news", False): + filter.append("News") + if tool_parameters.get("enable_related_search", False): + filter.append("RelatedSearches") + if tool_parameters.get("enable_webpages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + return self._invoke_bing( user_id=user_id, server_url=server_url, @@ -204,5 +198,5 @@ def _invoke(self, result_type=result_type, market=market, lang=lang, - filters=filter - ) \ No newline at end of file + filters=filter, + ) diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e5eada80ee430f..c24ee67334083b 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -13,11 +13,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index 21cbf2c7dae11d..94a4d928445319 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -37,7 +37,7 @@ def run(self, query: str) -> str: for item in web_search_results ] return json.dumps(final_results) - + def _search_request(self, query: str) -> list[dict]: headers = { "X-Subscription-Token": self.api_key, @@ -55,6 +55,7 @@ def _search_request(self, query: str) -> list[dict]: return response.json().get("web", {}).get("results", []) + class BraveSearch(BaseModel): """Tool that queries the BraveSearch.""" @@ -67,9 +68,7 @@ class BraveSearch(BaseModel): search_wrapper: BraveSearchWrapper @classmethod - def from_api_key( - cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any - ) -> "BraveSearch": + def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch": """Create a tool from an api key. Args: @@ -90,6 +89,7 @@ def _run( """Use the tool.""" return self.search_wrapper.run(query) + class BraveSearchTool(BuiltinTool): """ Tool for performing a search using Brave search engine. @@ -106,12 +106,12 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') - count = tool_parameters.get('count', 3) - api_key = self.runtime.credentials['brave_search_api_key'] + query = tool_parameters.get("query", "") + count = tool_parameters.get("count", 3) + api_key = self.runtime.credentials["brave_search_api_key"] if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) @@ -121,4 +121,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(f"No results found for '{query}' in Tavily") else: return self.create_text_message(text=results) - diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 0865bc700ac91c..8a24d33428599c 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -7,16 +7,34 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController # use a business theme -plt.style.use('seaborn-v0_8-darkgrid') -plt.rcParams['axes.unicode_minus'] = False +plt.style.use("seaborn-v0_8-darkgrid") +plt.rcParams["axes.unicode_minus"] = False + def init_fonts(): fonts = findSystemFonts() popular_unicode_fonts = [ - 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif', - 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans', - 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono' + "Arial Unicode MS", + "DejaVu Sans", + "DejaVu Sans Mono", + "DejaVu Serif", + "FreeMono", + "FreeSans", + "FreeSerif", + "Liberation Mono", + "Liberation Sans", + "Liberation Serif", + "Noto Mono", + "Noto Sans", + "Noto Serif", + "Open Sans", + "Roboto", + "Source Code Pro", + "Source Sans Pro", + "Source Serif Pro", + "Ubuntu", + "Ubuntu Mono", ] supported_fonts = [] @@ -25,21 +43,23 @@ def init_fonts(): try: font = TTFont(font_path) # get family name - family_name = font['name'].getName(1, 3, 1).toUnicode() + family_name = font["name"].getName(1, 3, 1).toUnicode() if family_name in popular_unicode_fonts: supported_fonts.append(family_name) except: pass - plt.rcParams['font.family'] = 'sans-serif' + plt.rcParams["font.family"] = "sans-serif" # sort by order of popular_unicode_fonts for font in popular_unicode_fonts: if font in supported_fonts: - plt.rcParams['font.sans-serif'] = font + plt.rcParams["font.sans-serif"] = font break - + + init_fonts() + class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: @@ -48,11 +68,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "data": "1,3,5,7,9,2,4,6,8,10", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 749ec761c692ec..3a47c0cfc0d47f 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -8,12 +8,13 @@ class BarChartTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -21,29 +22,27 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ else: data = [float(i) for i in data] - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.bar(axis, data) else: ax.bar(range(len(data)), data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the bar chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the bar chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 608bd6623cf71c..39e8caac7ef609 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -8,18 +8,19 @@ class LinearChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None @@ -32,20 +33,18 @@ def _invoke(self, flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.plot(axis, data) else: ax.plot(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the linear chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the linear chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index 4c551229e98565..2c3b8a733eac9a 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -8,15 +8,16 @@ class PieChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') - categories = tool_parameters.get('categories') or None + return self.create_text_message("Please input data") + data = data.split(";") + categories = tool_parameters.get("categories") or None # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -27,7 +28,7 @@ def _invoke(self, flg, ax = plt.subplots() if categories: - categories = categories.split(';') + categories = categories.split(";") if len(categories) != len(data): categories = None @@ -37,12 +38,11 @@ def _invoke(self, ax.pie(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the pie chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) - ] \ No newline at end of file + self.create_text_message("the pie chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), + ] diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 37645bf0d0189d..017fe548f76e68 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -8,15 +8,15 @@ class SimpleCode(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ - invoke simple code + invoke simple code """ - language = tool_parameters.get('language', CodeLanguage.PYTHON3) - code = tool_parameters.get('code', '') + language = tool_parameters.get("language", CodeLanguage.PYTHON3) + code = tool_parameters.get("code", "") if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: - raise ValueError(f'Only python3 and javascript are supported, not {language}') - - result = CodeExecutor.execute_code(language, '', code) + raise ValueError(f"Only python3 and javascript are supported, not {language}") - return self.create_text_message(result) \ No newline at end of file + result = CodeExecutor.execute_code(language, "", code) + + return self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py index 801817ec06ed36..6941ce86495693 100644 --- a/api/core/tools/provider/builtin/cogview/cogview.py +++ b/api/core/tools/provider/builtin/cogview/cogview.py @@ -1,4 +1,5 @@ -""" Provide the input parameters type for the cogview provider class """ +"""Provide the input parameters type for the cogview provider class""" + from typing import Any from core.tools.errors import ToolProviderCredentialValidationError @@ -7,7 +8,8 @@ class COGVIEWProvider(BuiltinToolProviderController): - """ cogview provider """ + """cogview provider""" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CogView3Tool().fork_tool_runtime( @@ -15,13 +17,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", "size": "square", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) from e - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 89ffcf3347878a..9776bd7dd1c02d 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -7,43 +7,42 @@ class CogView3Tool(BuiltinTool): - """ CogView3 Tool """ + """CogView3 Tool""" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke CogView3 tool """ client = ZhipuAI( - base_url=self.runtime.credentials['zhipuai_base_url'], - api_key=self.runtime.credentials['zhipuai_api_key'], + base_url=self.runtime.credentials["zhipuai_base_url"], + api_key=self.runtime.credentials["zhipuai_api_key"], ) size_mapping = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = size_mapping[tool_parameters.get('size', 'square')] + size = size_mapping[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} response = client.images.generations( prompt=prompt, model="cogview-3", @@ -52,18 +51,22 @@ def _invoke(self, extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py index 404e483e0d23ab..8ba3c1b48ae6d7 100644 --- a/api/core/tools/provider/builtin/crossref/crossref.py +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -11,9 +11,9 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "doi": '10.1007/s00894-022-05373-8', + "doi": "10.1007/s00894-022-05373-8", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py index a43c0989e40613..746139dd69d27b 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_doi.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool): """ Tool for querying the metadata of a publication using its DOI. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - doi = tool_parameters.get('doi') + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get("doi") if not doi: - raise ToolParameterValidationError('doi is required.') + raise ToolParameterValidationError("doi is required.") # doc: https://github.com/CrossRef/rest-api-doc url = f"https://api.crossref.org/works/{doi}" response = requests.get(url) response.raise_for_status() response = response.json() - message = response.get('message', {}) + message = response.get("message", {}) return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py index 946aa6dc947a20..e2452381832938 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_title.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int: Convert a time string to seconds. example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 """ - time_str = time_str.lower().strip().replace(' ', '') + time_str = time_str.lower().strip().replace(" ", "") seconds = 0 - if 'h' in time_str: - hours, time_str = time_str.split('h') + if "h" in time_str: + hours, time_str = time_str.split("h") seconds += int(hours) * 3600 - if 'm' in time_str: - minutes, time_str = time_str.split('m') + if "m" in time_str: + minutes, time_str = time_str.split("m") seconds += int(minutes) * 60 - if 's' in time_str: - seconds += int(time_str.replace('s', '')) + if "s" in time_str: + seconds += int(time_str.replace("s", "")) return seconds @@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI: Tool for querying the metadata of a publication using its title. Crossref API doc: https://github.com/CrossRef/rest-api-doc """ + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" rate_limit: int = 50 rate_interval: float = 1 @@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI: def __init__(self, mailto: str): self.mailto = mailto - def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def _query( + self, + query: str, + rows: int = 5, + offset: int = 0, + sort: str = "relevance", + order: str = "desc", + fuzzy_query: bool = False, + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -47,33 +56,37 @@ def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'releva :param order: the sort order :param fuzzy_query: whether to return all items that match the query """ - url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto) + url = self.query_url_template.format( + query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto + ) response = requests.get(url) response.raise_for_status() - rate_limit = int(response.headers['x-ratelimit-limit']) + rate_limit = int(response.headers["x-ratelimit-limit"]) # convert time string to seconds - rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval']) + rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"]) self.rate_limit = rate_limit self.rate_interval = rate_interval response = response.json() - if response['status'] != 'ok': + if response["status"] != "ok": return [] - message = response['message'] + message = response["message"] if fuzzy_query: # fuzzy query return all items - return message['items'] + return message["items"] else: - for paper in message['items']: - title = paper['title'][0] + for paper in message["items"]: + title = paper["title"][0] if title.lower() != query.lower(): continue return [paper] return [] - def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def query( + self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -89,7 +102,14 @@ def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = results = [] for i in range(query_times): - result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query) + result = self._query( + query, + rows=self.rate_limit, + offset=i * self.rate_limit, + sort=sort, + order=order, + fuzzy_query=fuzzy_query, + ) if fuzzy_query: results.extend(result) else: @@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool): """ Tool for querying the metadata of a publication using its title. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query') - fuzzy_query = tool_parameters.get('fuzzy_query', False) - rows = tool_parameters.get('rows', 3) - sort = tool_parameters.get('sort', 'relevance') - order = tool_parameters.get('order', 'desc') - mailto = self.runtime.credentials['mailto'] + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query") + fuzzy_query = tool_parameters.get("fuzzy_query", False) + rows = tool_parameters.get("rows", 3) + sort = tool_parameters.get("sort", "relevance") + order = tool_parameters.get("order", "desc") + mailto = self.runtime.credentials["mailto"] result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 1c8019364de9d2..5bd16e49e85e29 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -13,13 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "small", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index 9e9f32d42970a6..ac7e3949117ab1 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -9,59 +9,58 @@ class DallE2Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'small': '256x256', - 'medium': '512x512', - 'large': '1024x1024', + "small": "256x256", + "medium": "512x512", + "large": "1024x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - + return self.create_text_message("Please input prompt") + # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'large')] + size = SIZE_MAPPING[tool_parameters.get("size", "large")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # call openapi dalle2 - response = client.images.generate( - prompt=prompt, - model='dall-e-2', - size=size, - n=n, - response_format='b64_json' - ) + response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json") result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 4f5033dd7fe596..2d62cf608f46cb 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -10,69 +10,64 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # call openapi dalle3 response = client.images.generate( - prompt=prompt, - model='dall-e-3', - size=size, - n=n, - style=style, - quality=quality, - response_format='b64_json' + prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json" ) result = [] for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) - blob_message = self.create_blob_message(blob=blob_image, - meta={'mime_type': mime_type}, - save_as=self.VARIABLE_KEY.IMAGE.value) + blob_message = self.create_blob_message( + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value + ) result.append(blob_message) return result @@ -86,7 +81,7 @@ def _decode_image(base64_image: str) -> tuple[str, bytes]: :return: A tuple containing the MIME type and the decoded image bytes """ if DallE3Tool._is_plain_base64(base64_image): - return 'image/png', base64.b64decode(base64_image) + return "image/png", base64.b64decode(base64_image) else: return DallE3Tool._extract_mime_and_data(base64_image) @@ -98,7 +93,7 @@ def _is_plain_base64(encoded_str: str) -> bool: :param encoded_str: Base64 encoded image string :return: True if the string is plain base64, False otherwise """ - return not encoded_str.startswith('data:image') + return not encoded_str.startswith("data:image") @staticmethod def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: @@ -108,13 +103,13 @@ def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: :param encoded_str: Base64 encoded image string with MIME type prefix :return: A tuple containing the MIME type and the decoded image bytes """ - mime_type = encoded_str.split(';')[0].split(':')[1] - image_data_base64 = encoded_str.split(',')[1] + mime_type = encoded_str.split(";")[0].split(":")[1] + image_data_base64 = encoded_str.split(",")[1] decoded_data = base64.b64decode(image_data_base64) return mime_type, decoded_data @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 95d7939d0d9539..446c1e548935c0 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -11,7 +11,7 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "doc": "python~3.12", "topic": "library/code", @@ -19,4 +19,3 @@ def _validate_credentials(self, credentials: dict) -> None: ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py index 1a244c5db3f69a..e1effd066c611a 100644 --- a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py +++ b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py @@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel): class SearchDevDocsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the DevDocs search tool with the given user ID and tool parameters. @@ -24,13 +26,13 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - doc = tool_parameters.get('doc', '') - topic = tool_parameters.get('topic', '') + doc = tool_parameters.get("doc", "") + topic = tool_parameters.get("topic", "") if not doc: - return self.create_text_message('Please provide the documentation name.') + return self.create_text_message("Please provide the documentation name.") if not topic: - return self.create_text_message('Please provide the topic path.') + return self.create_text_message("Please provide the topic path.") url = f"https://documents.devdocs.io/{doc}/{topic}.html" response = requests.get(url) @@ -39,4 +41,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn content = response.text return self.create_text_message(self.summary(user_id=user_id, content=content)) else: - return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}") \ No newline at end of file + return self.create_text_message( + f"Failed to retrieve the documentation. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py index b4bf172131448d..5af78794f625b7 100644 --- a/api/core/tools/provider/builtin/did/did.py +++ b/api/core/tools/provider/builtin/did/did.py @@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the D-ID talks tool - TalksTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", "text_input": "Hello, welcome to use D-ID tool in Dify", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 964e82b729319e..4cad12e4ee2594 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -12,14 +12,14 @@ class DIDApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.d-id.com' + self.base_url = base_url or "https://api.d-id.com" if not self.api_key: - raise ValueError('API key is required') + raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( @@ -44,44 +44,44 @@ def _request( return None def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/talks' + endpoint = f"{self.base_url}/talks" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval) return id def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/animations' + endpoint = f"{self.base_url}/animations" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval) + return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval) return id def check_did_status(self, target: str, id: str): - endpoint = f'{self.base_url}/{target}/{id}' + endpoint = f"{self.base_url}/{target}/{id}" headers = self._prepare_headers() - response = self._request('GET', endpoint, headers=headers) + response = self._request("GET", endpoint, headers=headers) if response is None: - raise HTTPError(f'Failed to check status for talks {id} after multiple retries') + raise HTTPError(f"Failed to check status for talks {id} after multiple retries") return response def _monitor_job_status(self, target: str, id: str, poll_interval: int): while True: status = self.check_did_status(target=target, id=id) - if status['status'] == 'done': + if status["status"] == "done": return status - elif status['status'] == 'error' or status['status'] == 'rejected': + elif status["status"] == "error" or status["status"] == "rejected": raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py index e1d9de603fbb7a..bc9d17e40d2878 100644 --- a/api/core/tools/provider/builtin/did/tools/animations.py +++ b/api/core/tools/provider/builtin/did/tools/animations.py @@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None config = { - 'stitch': tool_parameters.get('stitch', True), - 'mute': tool_parameters.get('mute'), - 'result_format': tool_parameters.get('result_format') or 'mp4', + "stitch": tool_parameters.get("stitch", True), + "mute": tool_parameters.get("mute"), + "result_format": tool_parameters.get("result_format") or "mp4", } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if config.get('logo_url'): - if not config.get('logo_x'): - raise ValueError('Logo X position is required when logo URL is provided') - if not config.get('logo_y'): - raise ValueError('Logo Y position is required when logo URL is provided') + if config.get("logo_url"): + if not config.get("logo_x"): + raise ValueError("Logo X position is required when logo URL is provided") + if not config.get("logo_y"): + raise ValueError("Logo Y position is required when logo URL is provided") animations_result = app.animations(params=options, wait=True) @@ -44,6 +44,6 @@ def _invoke( animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4) if not animations_result: - return self.create_text_message('D-ID animations request failed.') + return self.create_text_message("D-ID animations request failed.") return self.create_text_message(animations_result) diff --git a/api/core/tools/provider/builtin/did/tools/talks.py b/api/core/tools/provider/builtin/did/tools/talks.py index 06b2c4cb2f6049..d6f0c7ff179793 100644 --- a/api/core/tools/provider/builtin/did/tools/talks.py +++ b/api/core/tools/provider/builtin/did/tools/talks.py @@ -10,49 +10,49 @@ class TalksTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None script = { - 'type': tool_parameters.get('script_type') or 'text', - 'input': tool_parameters.get('text_input'), - 'audio_url': tool_parameters.get('audio_url'), - 'reduce_noise': tool_parameters.get('audio_reduce_noise', False), + "type": tool_parameters.get("script_type") or "text", + "input": tool_parameters.get("text_input"), + "audio_url": tool_parameters.get("audio_url"), + "reduce_noise": tool_parameters.get("audio_reduce_noise", False), } - script = {k: v for k, v in script.items() if v is not None and v != ''} + script = {k: v for k, v in script.items() if v is not None and v != ""} config = { - 'stitch': tool_parameters.get('stitch', True), - 'sharpen': tool_parameters.get('sharpen'), - 'fluent': tool_parameters.get('fluent'), - 'result_format': tool_parameters.get('result_format') or 'mp4', - 'pad_audio': tool_parameters.get('pad_audio'), - 'driver_expressions': driver_expressions, + "stitch": tool_parameters.get("stitch", True), + "sharpen": tool_parameters.get("sharpen"), + "fluent": tool_parameters.get("fluent"), + "result_format": tool_parameters.get("result_format") or "mp4", + "pad_audio": tool_parameters.get("pad_audio"), + "driver_expressions": driver_expressions, } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'script': script, - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "script": script, + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if script.get('type') == 'audio': - script.pop('input', None) - if not script.get('audio_url'): - raise ValueError('Audio URL is required for audio script type') + if script.get("type") == "audio": + script.pop("input", None) + if not script.get("audio_url"): + raise ValueError("Audio URL is required for audio script type") - if script.get('type') == 'text': - script.pop('audio_url', None) - script.pop('reduce_noise', None) - if not script.get('input'): - raise ValueError('Text input is required for text script type') + if script.get("type") == "text": + script.pop("audio_url", None) + script.pop("reduce_noise", None) + if not script.get("input"): + raise ValueError("Text input is required for text script type") talks_result = app.talks(params=options, wait=True) @@ -60,6 +60,6 @@ def _invoke( talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) if not talks_result: - return self.create_text_message('D-ID talks request failed.') + return self.create_text_message("D-ID talks request failed.") return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py index c247c3bd6bcff0..f33ad5be59b403 100644 --- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -13,38 +13,43 @@ class DingTalkGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - Dingtalk custom group robot API docs: - https://open.dingtalk.com/document/orgapp/custom-robot-access + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access """ - content = tool_parameters.get('content') + content = tool_parameters.get("content") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - access_token = tool_parameters.get('access_token') + access_token = tool_parameters.get("access_token") if not access_token: - return self.create_text_message('Invalid parameter access_token. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter access_token. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - sign_secret = tool_parameters.get('sign_secret') + sign_secret = tool_parameters.get("sign_secret") if not sign_secret: - return self.create_text_message('Invalid parameter sign_secret. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter sign_secret. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - msgtype = 'text' - api_url = 'https://oapi.dingtalk.com/robot/send' + msgtype = "text" + api_url = "https://oapi.dingtalk.com/robot/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'access_token': access_token, + "access_token": access_token, } self._apply_security_mechanism(params, sign_secret) @@ -53,7 +58,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] "msgtype": msgtype, "text": { "content": content, - } + }, } try: @@ -62,7 +67,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) @@ -70,14 +76,14 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): try: timestamp = str(round(time.time() * 1000)) - secret_enc = sign_secret.encode('utf-8') - string_to_sign = f'{timestamp}\n{sign_secret}' - string_to_sign_enc = string_to_sign.encode('utf-8') + secret_enc = sign_secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{sign_secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - params['timestamp'] = timestamp - params['sign'] = sign + params["timestamp"] = timestamp + params["sign"] = sign except Exception: msg = "Failed to apply security mechanism to the request." logging.exception(msg) diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 2292e89fa6ed13..8269167127b8e5 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py index 878b0d86453a2b..8bdd638f4a01d1 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py @@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "model": tool_parameters.get('model'), + "keywords": tool_parameters.get("query"), + "model": tool_parameters.get("model"), } response = DDGS().chat(**query_dict) return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index bca53f6b4bd745..396570248ae785 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get('query'), - "timelimit": tool_parameters.get('timelimit'), - "size": tool_parameters.get('size'), - "max_results": tool_parameters.get('max_results'), + "keywords": tool_parameters.get("query"), + "timelimit": tool_parameters.get("timelimit"), + "size": tool_parameters.get("size"), + "max_results": tool_parameters.get("max_results"), } response = DDGS().images(**query_dict) result = [] for res in response: - res['transfer_method'] = FileTransferMethod.REMOTE_URL - msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=res.get('image'), - save_as='', - meta=res) + res["transfer_method"] = FileTransferMethod.REMOTE_URL + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res + ) result.append(msg) return result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index dfaeb734d8f667..cbd65d2e7756e0 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - query = tool_parameters.get('query') - max_results = tool_parameters.get('max_results', 5) - require_summary = tool_parameters.get('require_summary', False) + query = tool_parameters.get("query") + max_results = tool_parameters.get("max_results", 5) + require_summary = tool_parameters.get("require_summary", False) response = DDGS().text(query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) @@ -34,7 +35,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe def summary_results(self, user_id: str, content: str, query: str) -> str: prompt = SUMMARY_PROMPT.format(query=query, content=content) - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=prompt), - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[ + SystemPromptMessage(content=prompt), + ], + stop=[], + ) return summary.message.content diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py index 9822b37cf0231d..396ce21b183afc 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py @@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "to": tool_parameters.get('translate_to'), + "keywords": tool_parameters.get("query"), + "to": tool_parameters.get("translate_to"), } - response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!') + response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!") return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py index e8ab02f55ea4ed..e82da8ca534b96 100644 --- a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py +++ b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py @@ -8,35 +8,35 @@ class FeishuGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot + invoke tools + API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot """ url = "https://open.feishu.cn/open-apis/bot/v2/hook" - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - msg_type = 'text' - api_url = f'{url}/{hook_key}' + msg_type = "text" + api_url = f"{url}/{hook_key}" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { "msg_type": msg_type, "content": { "text": content, - } + }, } try: @@ -45,6 +45,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py index febb769ff83cc9..04056af53b5f95 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.py @@ -5,4 +5,4 @@ class FeishuBaseProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: GetTenantAccessTokenTool() - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py index be43b43ce47337..4a605fbffeef0b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py @@ -8,45 +8,49 @@ class AddBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to add base record, status code: {res.status_code}, response: {res.text}") + f"Failed to add base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py index 639644e7f0e3ea..6b755e2007d7d6 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py @@ -8,28 +8,25 @@ class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - name = tool_parameters.get('name', '') - folder_token = tool_parameters.get('folder_token', '') + name = tool_parameters.get("name", "") + folder_token = tool_parameters.get("folder_token", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "name": name, - "folder_token": folder_token - } + payload = {"name": name, "folder_token": folder_token} try: res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) @@ -38,6 +35,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base, status code: {res.status_code}, response: {res.text}") + f"Failed to create base, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py index e9062e8730f9ac..b05d700113880b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py @@ -8,37 +8,32 @@ class CreateBaseTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - name = tool_parameters.get('name', '') + name = tool_parameters.get("name", "") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table": { - "name": name, - "fields": json.loads(fields) - } - } + payload = {"table": {"name": name, "fields": json.loads(fields)}} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -47,6 +42,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}") + f"Failed to create base table, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base table. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py index aa13aad6fac287..862eb2171b9269 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py @@ -8,45 +8,49 @@ class DeleteBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_ids = tool_parameters.get('record_ids', '') + record_ids = tool_parameters.get("record_ids", "") if not record_ids: - return self.create_text_message('Invalid parameter record_ids') + return self.create_text_message("Invalid parameter record_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "records": json.loads(record_ids) - } + payload = {"records": json.loads(record_ids)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py index c4280ebc21eaed..f5121863035313 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py @@ -8,32 +8,30 @@ class DeleteBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_ids = tool_parameters.get('table_ids', '') + table_ids = tool_parameters.get("table_ids", "") if not table_ids: - return self.create_text_message('Invalid parameter table_ids') + return self.create_text_message("Invalid parameter table_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table_ids": json.loads(table_ids) - } + payload = {"table_ids": json.loads(table_ids)} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -42,6 +40,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py index de70f2ed9359dc..f664bbeed08693 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py @@ -8,22 +8,22 @@ class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: @@ -33,6 +33,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get base info, status code: {res.status_code}, response: {res.text}") + f"Failed to get base info, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get base info. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py index 88507bda60090f..2ea61d0068237b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py @@ -8,27 +8,24 @@ class GetTenantAccessTokenTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - app_id = tool_parameters.get('app_id', '') + app_id = tool_parameters.get("app_id", "") if not app_id: - return self.create_text_message('Invalid parameter app_id') + return self.create_text_message("Invalid parameter app_id") - app_secret = tool_parameters.get('app_secret', '') + app_secret = tool_parameters.get("app_secret", "") if not app_secret: - return self.create_text_message('Invalid parameter app_secret') + return self.create_text_message("Invalid parameter app_secret") headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} """ { @@ -45,6 +42,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}") + f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get tenant access token. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py index 2a4229f137d7fd..e579d02f6967e7 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py @@ -8,31 +8,31 @@ class ListBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - sort_condition = tool_parameters.get('sort_condition', '') - filter_condition = tool_parameters.get('filter_condition', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + sort_condition = tool_parameters.get("sort_condition", "") + filter_condition = tool_parameters.get("filter_condition", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -40,22 +40,26 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] "page_size": page_size, } - payload = { - "automatic_fields": True - } + payload = {"automatic_fields": True} if sort_condition: payload["sort"] = json.loads(sort_condition) if filter_condition: payload["filter"] = json.loads(filter_condition) try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}") + f"Failed to list base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py index 6d82490eb3235f..4ec9a476bc8832 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py @@ -8,25 +8,25 @@ class ListBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -41,6 +41,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py index bb4bd6c3a6c531..fb818f838073fa 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py @@ -8,40 +8,42 @@ class ReadBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: - res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - timeout=30) + res = httpx.get( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}") + f"Failed to read base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to read base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py index 6551053ce22535..6d7e33f3ffef7c 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py @@ -8,49 +8,53 @@ class UpdateBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - params=params, json=payload, timeout=30) + res = httpx.put( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}") + f"Failed to update base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to update base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py index c4f8f26e2c9a9b..b0a1e393eb8116 100644 --- a/api/core/tools/provider/builtin/feishu_document/feishu_document.py +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -5,11 +5,11 @@ class FeishuDocumentProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 0ff82e621b5722..090a0828e89bbf 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,13 @@ class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get('title') - content = tool_parameters.get('content') - folder_token = tool_parameters.get('folder_token') + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py index 16ef90908bda57..83073e08223a13 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -7,11 +7,11 @@ class GetDocumentRawContentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') + document_id = tool_parameters.get("document_id") res = client.get_document_raw_content(document_id) - return self.create_json_message(res) \ No newline at end of file + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py index 97d17bdb04767a..8c0c4a3c97af0f 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -7,13 +7,13 @@ class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - page_size = tool_parameters.get('page_size', 500) - page_token = tool_parameters.get('page_token', '') + document_id = tool_parameters.get("document_id") + page_size = tool_parameters.get("page_size", 500) + page_token = tool_parameters.get("page_token", "") res = client.list_document_block(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py index 914a44dce62483..6061250e48e136 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -7,13 +7,13 @@ class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - content = tool_parameters.get('content') - position = tool_parameters.get('position') + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position") res = client.write_document(document_id, content, position) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py index 6d7fed330c545e..7b3adb9293750b 100644 --- a/api/core/tools/provider/builtin/feishu_message/feishu_message.py +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -5,11 +5,11 @@ class FeishuMessageProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py index 74f6866ba3c3a0..1dd315d0e293a0 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -7,14 +7,14 @@ class SendBotMessageTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - receive_id_type = tool_parameters.get('receive_id_type') - receive_id = tool_parameters.get('receive_id') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py index 7159f59ffa5c67..44e70e0a15b64d 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -6,14 +6,14 @@ class SendWebhookMessageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - webhook = tool_parameters.get('webhook') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_webhook_message(webhook, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 24dc35759d8e6d..01455d7206f185 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the ScrapeTool, only scraping title for minimize content - ScrapeTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', - tool_parameters={ - "url": "https://google.com", - "onlyIncludeTags": 'title' - } + ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"} ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index 3b3f78731b3de4..a0e4cdf9332799 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -13,27 +13,24 @@ class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' + self.base_url = base_url or "https://api.firecrawl.dev" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: if not headers: headers = self._prepare_headers() @@ -44,54 +41,54 @@ def _request( return response.json() except requests.exceptions.RequestException as e: if i < retries - 1: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None def scrape_url(self, url: str, **kwargs): - endpoint = f'{self.base_url}/v0/scrape' - data = {'url': url, **kwargs} + endpoint = f"{self.base_url}/v0/scrape" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response def search(self, query: str, **kwargs): - endpoint = f'{self.base_url}/v0/search' - data = {'query': query, **kwargs} + endpoint = f"{self.base_url}/v0/search" + data = {"query": query, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to perform search after multiple retries") return response def crawl_url( - self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f'{self.base_url}/v0/crawl' + endpoint = f"{self.base_url}/v0/crawl" headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs} + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response['jobId'] + job_id: str = response["jobId"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - response = self._request('GET', endpoint) + endpoint = f"{self.base_url}/v0/crawl/status/{job_id}" + response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}' - response = self._request('DELETE', endpoint) + endpoint = f"{self.base_url}/v0/crawl/cancel/{job_id}" + response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") return response @@ -99,9 +96,9 @@ def cancel_crawl_job(self, job_id: str): def _monitor_job_status(self, job_id: str, poll_interval: int): while True: status = self.check_crawl_status(job_id) - if status['status'] == 'completed': + if status["status"] == "completed": return status - elif status['status'] == 'failed': + elif status["status"] == "failed": raise HTTPError(f'Job {job_id} failed: {status["error"]}') time.sleep(poll_interval) @@ -109,7 +106,7 @@ def _monitor_job_status(self, job_id: str, poll_interval: int): def get_array_params(tool_parameters: dict[str, Any], key): param = tool_parameters.get(key) if param: - return param.split(',') + return param.split(",") def get_json_params(tool_parameters: dict[str, Any], key): diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 08c40a4064c511..94717cbbfbefc0 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -11,38 +11,36 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe the crawlerOptions and pageOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) crawlerOptions = {} pageOptions = {} - wait_for_results = tool_parameters.get('wait_for_results', True) - - crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes') - crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes') - crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False) - crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth') - crawlerOptions['mode'] = tool_parameters.get('mode') - crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False) - crawlerOptions['limit'] = tool_parameters.get('limit', 5) - crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False) - crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False) - - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) + wait_for_results = tool_parameters.get("wait_for_results", True) + + crawlerOptions["excludes"] = get_array_params(tool_parameters, "excludes") + crawlerOptions["includes"] = get_array_params(tool_parameters, "includes") + crawlerOptions["returnOnlyUrls"] = tool_parameters.get("returnOnlyUrls", False) + crawlerOptions["maxDepth"] = tool_parameters.get("maxDepth") + crawlerOptions["mode"] = tool_parameters.get("mode") + crawlerOptions["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + crawlerOptions["limit"] = tool_parameters.get("limit", 5) + crawlerOptions["allowBackwardCrawling"] = tool_parameters.get("allowBackwardCrawling", False) + crawlerOptions["allowExternalContentLinks"] = tool_parameters.get("allowExternalContentLinks", False) + + pageOptions["headers"] = get_json_params(tool_parameters, "headers") + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") + pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) + pageOptions["screenshot"] = tool_parameters.get("screenshot", False) + pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) crawl_result = app.crawl_url( - url=tool_parameters['url'], - wait=wait_for_results, - crawlerOptions=crawlerOptions, - pageOptions=pageOptions + url=tool_parameters["url"], wait=wait_for_results, crawlerOptions=crawlerOptions, pageOptions=pageOptions ) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py index fa6c1f87ee2c42..0d2486c7ca4426 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py @@ -7,14 +7,15 @@ class CrawlJobTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - operation = tool_parameters.get('operation', 'get') - if operation == 'get': - result = app.check_crawl_status(job_id=tool_parameters['job_id']) - elif operation == 'cancel': - result = app.cancel_crawl_job(job_id=tool_parameters['job_id']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + operation = tool_parameters.get("operation", "get") + if operation == "get": + result = app.check_crawl_status(job_id=tool_parameters["job_id"]) + elif operation == "cancel": + result = app.cancel_crawl_job(job_id=tool_parameters["job_id"]) else: - raise ValueError(f'Invalid operation: {operation}') + raise ValueError(f"Invalid operation: {operation}") return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 91412da548a0b6..962570bf737a57 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ the pageOptions and extractorOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) pageOptions = {} extractorOptions = {} - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) - - extractorOptions['mode'] = tool_parameters.get('mode', '') - extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '') - extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema') - - crawl_result = app.scrape_url(url=tool_parameters['url'], - pageOptions=pageOptions, - extractorOptions=extractorOptions) + pageOptions["headers"] = get_json_params(tool_parameters, "headers") + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") + pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) + pageOptions["screenshot"] = tool_parameters.get("screenshot", False) + pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) + + extractorOptions["mode"] = tool_parameters.get("mode", "") + extractorOptions["extractionPrompt"] = tool_parameters.get("extractionPrompt", "") + extractorOptions["extractionSchema"] = get_json_params(tool_parameters, "extractionSchema") + + crawl_result = app.scrape_url( + url=tool_parameters["url"], pageOptions=pageOptions, extractorOptions=extractorOptions + ) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py index e2b2ac6b4dddb6..f077e7d8ea2835 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/search.py @@ -11,18 +11,17 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe the pageOptions and searchOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/search """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) pageOptions = {} - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True) - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - searchOptions = {'limit': tool_parameters.get('limit')} + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["fetchPageContent"] = tool_parameters.get("fetchPageContent", True) + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + searchOptions = {"limit": tool_parameters.get("limit")} search_result = app.search( - query=tool_parameters['keyword'], - pageOptions=pageOptions, - searchOptions=searchOptions + query=tool_parameters["keyword"], pageOptions=pageOptions, searchOptions=searchOptions ) return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index b55d93e07b0d2b..a3e50da0012efd 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -9,17 +9,19 @@ class GaodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'api_key' not in credentials or not credentials.get('api_key'): + if "api_key" not in credentials or not credentials.get("api_key"): raise ToolProviderCredentialValidationError("Gaode API key is required.") try: - response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" - "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'), - apikey=credentials.get('api_key'))) - if response.status_code == 200 and (response.json()).get('info') == 'OK': + response = requests.get( + url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" "".format( + address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key") + ) + ) + if response.status_code == 200 and (response.json()).get("info") == "OK": pass else: - raise ToolProviderCredentialValidationError((response.json()).get('info')) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index efd11cedce4238..843504eefd962d 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -8,50 +8,57 @@ class GaodeRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - city = tool_parameters.get('city', '') + city = tool_parameters.get("city", "") if not city: - return self.create_text_message('Please tell me your city') + return self.create_text_message("Please tell me your city") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Gaode API key is required.") try: s = requests.session() - api_domain = 'https://restapi.amap.com/v3' - city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"}, - url="{url}/config/district?keywords={keywords}" - "&subdistrict=0&extensions=base&key={apikey}" - "".format(url=api_domain, keywords=city, - apikey=self.runtime.credentials.get('api_key'))) + api_domain = "https://restapi.amap.com/v3" + city_response = s.request( + method="GET", + headers={"Content-Type": "application/json; charset=utf-8"}, + url="{url}/config/district?keywords={keywords}" "&subdistrict=0&extensions=base&key={apikey}" "".format( + url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key") + ), + ) City_data = city_response.json() - if city_response.status_code == 200 and City_data.get('info') == 'OK': - if len(City_data.get('districts')) > 0: - CityCode = City_data['districts'][0]['adcode'] - weatherInfo_response = s.request(method='GET', - url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" - "".format(url=api_domain, citycode=CityCode, - apikey=self.runtime.credentials.get('api_key'))) + if city_response.status_code == 200 and City_data.get("info") == "OK": + if len(City_data.get("districts")) > 0: + CityCode = City_data["districts"][0]["adcode"] + weatherInfo_response = s.request( + method="GET", + url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" + "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), + ) weatherInfo_data = weatherInfo_response.json() - if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': + if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": contents = [] - if len(weatherInfo_data.get('forecasts')) > 0: - for item in weatherInfo_data['forecasts'][0]['casts']: + if len(weatherInfo_data.get("forecasts")) > 0: + for item in weatherInfo_data["forecasts"][0]["casts"]: content = {} - content['date'] = item.get('date') - content['week'] = item.get('week') - content['dayweather'] = item.get('dayweather') - content['daytemp_float'] = item.get('daytemp_float') - content['daywind'] = item.get('daywind') - content['nightweather'] = item.get('nightweather') - content['nighttemp_float'] = item.get('nighttemp_float') + content["date"] = item.get("date") + content["week"] = item.get("week") + content["dayweather"] = item.get("dayweather") + content["daytemp_float"] = item.get("daytemp_float") + content["daywind"] = item.get("daywind") + content["nightweather"] = item.get("nightweather") + content["nighttemp_float"] = item.get("nighttemp_float") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) s.close() - return self.create_text_message(f'No weather information for {city} was found.') + return self.create_text_message(f"No weather information for {city} was found.") except Exception as e: return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py index c81d5fa333cd5d..bbd07d120fd0ea 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the text2image tool - Text2ImageTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "prompt": "A fire egg", "response_format": "url", "style": "photorealism", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py index e28c57649cac4c..0e95a5f654505f 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -8,18 +8,16 @@ logger = logging.getLogger(__name__) + class GetImgAIApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.getimg.ai/v1' + self.base_url = base_url or "https://api.getimg.ai/v1" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return headers def _request( @@ -38,22 +36,20 @@ def _request( return response.json() except requests.exceptions.RequestException as e: if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None - def text2image( - self, mode: str, **kwargs - ): - data = kwargs['params'] - if not data.get('prompt'): + def text2image(self, mode: str, **kwargs): + data = kwargs["params"] + if not data.get("prompt"): raise ValueError("Prompt is required") - endpoint = f'{self.base_url}/{mode}/text-to-image' + endpoint = f"{self.base_url}/{mode}/text-to-image" headers = self._prepare_headers() logger.debug(f"Send request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate getimg.ai after multiple retries") return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py index dad7314479a89d..c556749552c8ef 100644 --- a/api/core/tools/provider/builtin/getimgai/tools/text2image.py +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -7,28 +7,28 @@ class Text2ImageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp( + api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"] + ) options = { - 'style': tool_parameters.get('style'), - 'prompt': tool_parameters.get('prompt'), - 'aspect_ratio': tool_parameters.get('aspect_ratio'), - 'output_format': tool_parameters.get('output_format', 'jpeg'), - 'response_format': tool_parameters.get('response_format', 'url'), - 'width': tool_parameters.get('width'), - 'height': tool_parameters.get('height'), - 'steps': tool_parameters.get('steps'), - 'negative_prompt': tool_parameters.get('negative_prompt'), - 'prompt_2': tool_parameters.get('prompt_2'), + "style": tool_parameters.get("style"), + "prompt": tool_parameters.get("prompt"), + "aspect_ratio": tool_parameters.get("aspect_ratio"), + "output_format": tool_parameters.get("output_format", "jpeg"), + "response_format": tool_parameters.get("response_format", "url"), + "width": tool_parameters.get("width"), + "height": tool_parameters.get("height"), + "steps": tool_parameters.get("steps"), + "negative_prompt": tool_parameters.get("negative_prompt"), + "prompt_2": tool_parameters.get("prompt_2"), } options = {k: v for k, v in options.items() if v} - text2image_result = app.text2image( - mode=tool_parameters.get('mode', 'essential-v2'), - params=options, - wait=True - ) + text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True) if not isinstance(text2image_result, str): text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index b19f0896f8ed79..87a34ac3e806ea 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -7,25 +7,25 @@ class GithubProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Github API Access Tokens is required.") - if 'api_version' not in credentials or not credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in credentials or not credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = credentials.get('api_version') + api_version = credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } response = requests.get( - url="https://api.github.com/search/users?q={account}".format(account='charli117'), - headers=headers) + url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers + ) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index 305bf08ce8e9d7..3eab8bf8dc64ae 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -10,53 +10,61 @@ class GithubRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - top_n = tool_parameters.get('top_n', 5) - query = tool_parameters.get('query', '') + top_n = tool_parameters.get("top_n", 5) + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input symbol') + return self.create_text_message("Please input symbol") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Github API Access Tokens is required.") - if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = self.runtime.credentials.get('api_version') + api_version = self.runtime.credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } s = requests.session() - api_domain = 'https://api.github.com' - response = s.request(method='GET', headers=headers, - url=f"{api_domain}/search/repositories?" - f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") + api_domain = "https://api.github.com" + response = s.request( + method="GET", + headers=headers, + url=f"{api_domain}/search/repositories?" f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc", + ) response_data = response.json() - if response.status_code == 200 and isinstance(response_data.get('items'), list): + if response.status_code == 200 and isinstance(response_data.get("items"), list): contents = [] - if len(response_data.get('items')) > 0: - for item in response_data.get('items'): + if len(response_data.get("items")) > 0: + for item in response_data.get("items"): content = {} - updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") - content['owner'] = item['owner']['login'] - content['name'] = item['name'] - content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description'] - content['url'] = item['html_url'] - content['star'] = item['watchers'] - content['forks'] = item['forks'] - content['updated'] = updated_at_object.strftime("%Y-%m-%d") + updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ") + content["owner"] = item["owner"]["login"] + content["name"] = item["name"] + content["description"] = ( + item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"] + ) + content["url"] = item["html_url"] + content["star"] = item["watchers"] + content["forks"] = item["forks"] + content["updated"] = updated_at_object.strftime("%Y-%m-%d") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) else: - return self.create_text_message(f'No items related to {query} were found.') + return self.create_text_message(f"No items related to {query} were found.") else: - return self.create_text_message((response.json()).get('message')) + return self.create_text_message((response.json()).get("message")) except Exception as e: return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index 0c13ec662a4f98..9bd4a0bd52ea64 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -9,13 +9,13 @@ class GitlabProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") - - if 'site_url' not in credentials or not credentials.get('site_url'): - site_url = 'https://gitlab.com' + + if "site_url" not in credentials or not credentials.get("site_url"): + site_url = "https://gitlab.com" else: - site_url = credentials.get('site_url') + site_url = credentials.get("site_url") try: headers = { @@ -23,12 +23,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "Authorization": f"Bearer {credentials.get('access_tokens')}", } - response = requests.get( - url= f"{site_url}/api/v4/user", - headers=headers) + response = requests.get(url=f"{site_url}/api/v4/user", headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 0824eb3a26346a..dceb37db493ee1 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -9,39 +9,47 @@ class GitlabCommitsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - employee = tool_parameters.get('employee', '') - start_time = tool_parameters.get('start_time', '') - end_time = tool_parameters.get('end_time', '') - change_type = tool_parameters.get('change_type', 'all') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + employee = tool_parameters.get("employee", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + change_type = tool_parameters.get("change_type", "all") if not project: - return self.create_text_message('Project is required') + return self.create_text_message("Project is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() if not end_time: end_time = datetime.utcnow().isoformat() - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get commit content result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) return [self.create_json_message(item) for item in result] - - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + + def fetch( + self, + user_id: str, + site_url: str, + access_token: str, + project: str, + employee: str = None, + start_time: str = "", + end_time: str = "", + change_type: str = "", + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] @@ -53,59 +61,66 @@ def fetch(self,user_id: str, site_url: str, access_token: str, project: str, emp response.raise_for_status() projects = response.json() - filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + filtered_projects = [p for p in projects if project == "*" or p["name"] == project] for project in filtered_projects: - project_id = project['id'] - project_name = project['name'] + project_id = project["id"] + project_name = project["name"] print(f"Project: {project_name}") # Get all of project commits commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = { - 'since': start_time, - 'until': end_time - } + params = {"since": start_time, "until": end_time} if employee: - params['author'] = employee + params["author"] = employee commits_response = requests.get(commits_url, headers=headers, params=params) commits_response.raise_for_status() commits = commits_response.json() for commit in commits: - commit_sha = commit['id'] - author_name = commit['author_name'] + commit_sha = commit["id"] + author_name = commit["author_name"] diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" diff_response = requests.get(diff_url, headers=headers) diff_response.raise_for_status() diffs = diff_response.json() - + for diff in diffs: # Calculate code lines of changed - added_lines = diff['diff'].count('\n+') - removed_lines = diff['diff'].count('\n-') + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") total_changes = added_lines + removed_lines if change_type == "new": if added_lines > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code - }) + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code} + ) else: if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code_escaped - }) + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 7fa1d0d1124bab..4a42b0fd7306c9 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -7,32 +7,29 @@ class GitlabFilesTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - branch = tool_parameters.get('branch', '') - path = tool_parameters.get('path', '') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + branch = tool_parameters.get("branch", "") + path = tool_parameters.get("path", "") if not project: - return self.create_text_message('Project is required') + return self.create_text_message("Project is required") if not branch: - return self.create_text_message('Branch is required') + return self.create_text_message("Branch is required") if not path: - return self.create_text_message('Path is required') + return self.create_text_message("Path is required") - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get project ID from project name project_id = self.get_project_id(site_url, access_token, project) if not project_id: @@ -42,9 +39,9 @@ def _invoke(self, result = self.fetch(user_id, project_id, site_url, access_token, branch, path) return [self.create_json_message(item) for item in result] - + def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split('/', 1) + parts = path.split("/", 1) if len(parts) < 2: return None, None return parts[0], parts[1] @@ -57,13 +54,15 @@ def get_project_id(self, site_url: str, access_token: str, project_name: str) -> response.raise_for_status() projects = response.json() for project in projects: - if project['name'] == project_name: - return project['id'] + if project["name"] == project_name: + return project["id"] except requests.RequestException as e: print(f"Error fetching project ID from GitLab: {e}") return None - - def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: + + def fetch( + self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] @@ -76,20 +75,16 @@ def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, items = response.json() for item in items: - item_path = item['path'] - if item['type'] == 'tree': # It's a directory + item_path = item["path"] + if item["type"] == "tree": # It's a directory results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) else: # It's a file file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" file_response = requests.get(file_url, headers=headers) file_response.raise_for_status() file_content = file_response.text - results.append({ - "path": item_path, - "branch": branch, - "content": file_content - }) + results.append({"path": item_path, "branch": branch, "content": file_content}) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 8f4b9a4a4e9784..6b5395f9d3e5b8 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 09d0326fb4a0b7..a9f65925d86f94 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -9,7 +9,6 @@ class GoogleSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledge_graph" in response: @@ -17,25 +16,23 @@ def _parse_response(self, response: dict) -> dict: result["description"] = response["knowledge_graph"].get("description", "") if "organic_results" in response: result["organic_results"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic_results"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: params = { - "api_key": self.runtime.credentials['serpapi_api_key'], - "q": tool_parameters['query'], + "api_key": self.runtime.credentials["serpapi_api_key"], + "q": tool_parameters["query"], "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en" + "hl": "en", } response = requests.get(url=SERP_API_URL, params=params) response.raise_for_status() diff --git a/api/core/tools/provider/builtin/google_translate/google_translate.py b/api/core/tools/provider/builtin/google_translate/google_translate.py index f6e1d65834798b..ea53aa4eeb906f 100644 --- a/api/core/tools/provider/builtin/google_translate/google_translate.py +++ b/api/core/tools/provider/builtin/google_translate/google_translate.py @@ -8,10 +8,6 @@ class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - GoogleTranslate().invoke(user_id='', - tool_parameters={ - "content": "这是一段测试文本", - "dest": "en" - }) + GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google_translate/tools/translate.py b/api/core/tools/provider/builtin/google_translate/tools/translate.py index 4314182b06dbbc..5d57b5fabfb94f 100644 --- a/api/core/tools/provider/builtin/google_translate/tools/translate.py +++ b/api/core/tools/provider/builtin/google_translate/tools/translate.py @@ -7,46 +7,40 @@ class GoogleTranslate(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - dest = tool_parameters.get('dest', '') + dest = tool_parameters.get("dest", "") if not dest: - return self.create_text_message('Invalid parameter destination language') + return self.create_text_message("Invalid parameter destination language") try: result = self._translate(content, dest) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Translation service error, please check the network') + return self.create_text_message("Translation service error, please check the network") def _translate(self, content: str, dest: str) -> str: try: url = "https://translate.googleapis.com/translate_a/single" - params = { - "client": "gtx", - "sl": "auto", - "tl": dest, - "dt": "t", - "q": content - } + params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content} headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - response_json = requests.get( - url, params=params, headers=headers).json() + response_json = requests.get(url, params=params, headers=headers).json() result = response_json[0] - translated_text = ''.join([item[0] for item in result if item[0]]) + translated_text = "".join([item[0] for item in result if item[0]]) return str(translated_text) except Exception as e: return str(e) diff --git a/api/core/tools/provider/builtin/hap/hap.py b/api/core/tools/provider/builtin/hap/hap.py index e0a48e05a5ef8c..cbdf9504659568 100644 --- a/api/core/tools/provider/builtin/hap/hap.py +++ b/api/core/tools/provider/builtin/hap/hap.py @@ -5,4 +5,4 @@ class HapProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index 0e101dc67daa13..f2288ed81c58e7 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -8,41 +8,40 @@ class AddWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Worksheet ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/addRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to add the new record. {res_json['error_msg']}") return self.create_text_message(f"New record added successfully. The record ID is {res_json['data']}.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index ba25952c9f4ea6..1df5f6d5cf1083 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -7,43 +7,42 @@ class DeleteWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/deleteRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=30) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to delete the record. {res_json['error_msg']}") return self.create_text_message("Successfully deleted the record.") except httpx.RequestError as e: return self.create_text_message(f"Failed to delete the record, request error: {e}") except Exception as e: - return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") \ No newline at end of file + return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 2c46d9dd4e7392..69cf8aa740b16b 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -8,43 +8,42 @@ class GetWorksheetFieldsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Worksheet ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to get the worksheet information. {res_json['error_msg']}") - - fields_json, fields_table = self.get_controls(res_json['data']['controls']) - result_type = tool_parameters.get('result_type', 'table') + + fields_json, fields_table = self.get_controls(res_json["data"]["controls"]) + result_type = tool_parameters.get("result_type", "table") return self.create_text_message( - text=json.dumps(fields_json, ensure_ascii=False) if result_type == 'json' else fields_table + text=json.dumps(fields_json, ensure_ascii=False) if result_type == "json" else fields_table ) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet information, request error: {e}") @@ -88,61 +87,65 @@ def get_field_type_by_id(self, field_type_id: int) -> str: 50: "Text", 51: "Query Record", } - return field_type_map.get(field_type_id, '') + return field_type_map.get(field_type_id, "") def get_controls(self, controls: list) -> dict: fields = [] - fields_list = ['|fieldId|fieldName|fieldType|fieldTypeId|description|options|','|'+'---|'*6] + fields_list = ["|fieldId|fieldName|fieldType|fieldTypeId|description|options|", "|" + "---|" * 6] for control in controls: - if control['type'] in self._get_ignore_types(): + if control["type"] in self._get_ignore_types(): continue - field_type_id = control['type'] - field_type = self.get_field_type_by_id(control['type']) + field_type_id = control["type"] + field_type = self.get_field_type_by_id(control["type"]) if field_type_id == 30: - source_type = control['sourceControl']['type'] + source_type = control["sourceControl"]["type"] if source_type in self._get_ignore_types(): continue else: field_type_id = source_type field_type = self.get_field_type_by_id(source_type) field = { - 'id': control['controlId'], - 'name': control['controlName'], - 'type': field_type, - 'typeId': field_type_id, - 'description': control['remark'].replace('\n', ' ').replace('\t', ' '), - 'options': self._extract_options(control), + "id": control["controlId"], + "name": control["controlName"], + "type": field_type, + "typeId": field_type_id, + "description": control["remark"].replace("\n", " ").replace("\t", " "), + "options": self._extract_options(control), } fields.append(field) - fields_list.append(f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|") + fields_list.append( + f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|" + ) - fields.append({ - 'id': 'ctime', - 'name': 'Created Time', - 'type': self.get_field_type_by_id(16), - 'typeId': 16, - 'description': '', - 'options': [] - }) + fields.append( + { + "id": "ctime", + "name": "Created Time", + "type": self.get_field_type_by_id(16), + "typeId": 16, + "description": "", + "options": [], + } + ) fields_list.append("|ctime|Created Time|Date|16|||") - return fields, '\n'.join(fields_list) + return fields, "\n".join(fields_list) def _extract_options(self, control: dict) -> list: options = [] - if control['type'] in [9, 10, 11]: - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) - elif control['type'] in [28, 36]: - itemnames = control['advancedSetting'].get('itemnames') - if itemnames and itemnames.startswith('[{'): + if control["type"] in [9, 10, 11]: + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) + elif control["type"] in [28, 36]: + itemnames = control["advancedSetting"].get("itemnames") + if itemnames and itemnames.startswith("[{"): try: options = json.loads(itemnames) except json.JSONDecodeError: pass - elif control['type'] == 30: - source_type = control['sourceControl']['type'] + elif control["type"] == 30: + source_type = control["sourceControl"]["type"] if source_type not in self._get_ignore_types(): - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) return options - + def _get_ignore_types(self): - return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} \ No newline at end of file + return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 6bf1caa65ec337..6b831f3145b714 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -8,64 +8,66 @@ class GetWorksheetPivotDataTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - x_column_fields = tool_parameters.get('x_column_fields', '') - if not x_column_fields or not x_column_fields.startswith('['): - return self.create_text_message('Invalid parameter Column Fields') - y_row_fields = tool_parameters.get('y_row_fields', '') - if y_row_fields and not y_row_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Row Fields') + return self.create_text_message("Invalid parameter Worksheet ID") + x_column_fields = tool_parameters.get("x_column_fields", "") + if not x_column_fields or not x_column_fields.startswith("["): + return self.create_text_message("Invalid parameter Column Fields") + y_row_fields = tool_parameters.get("y_row_fields", "") + if y_row_fields and not y_row_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Row Fields") elif not y_row_fields: - y_row_fields = '[]' - value_fields = tool_parameters.get('value_fields', '') - if not value_fields or not value_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Value Fields') - - host = tool_parameters.get('host', '') + y_row_fields = "[]" + value_fields = tool_parameters.get("value_fields", "") + if not value_fields or not value_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Value Fields") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/report/getPivotData" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "options": {"showTotal": True}} try: x_column_fields = json.loads(x_column_fields) - payload['columns'] = x_column_fields + payload["columns"] = x_column_fields y_row_fields = json.loads(y_row_fields) - if y_row_fields: payload['rows'] = y_row_fields + if y_row_fields: + payload["rows"] = y_row_fields value_fields = json.loads(value_fields) - payload['values'] = value_fields - sort_fields = tool_parameters.get('sort_fields', '') - if not sort_fields: sort_fields = '[]' + payload["values"] = value_fields + sort_fields = tool_parameters.get("sort_fields", "") + if not sort_fields: + sort_fields = "[]" sort_fields = json.loads(sort_fields) - if sort_fields: payload['options']['sort'] = sort_fields + if sort_fields: + payload["options"]["sort"] = sort_fields res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('status') != 1: + if res_json.get("status") != 1: return self.create_text_message(f"Failed to get the worksheet pivot data. {res_json['msg']}") - - pivot_json = self.generate_pivot_json(res_json['data']) - pivot_table = self.generate_pivot_table(res_json['data']) - result_type = tool_parameters.get('result_type', '') - text = pivot_table if result_type == 'table' else json.dumps(pivot_json, ensure_ascii=False) + + pivot_json = self.generate_pivot_json(res_json["data"]) + pivot_table = self.generate_pivot_table(res_json["data"]) + result_type = tool_parameters.get("result_type", "") + text = pivot_table if result_type == "table" else json.dumps(pivot_json, ensure_ascii=False) return self.create_text_message(text) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet pivot data, request error: {e}") @@ -75,27 +77,31 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(f"Failed to get the worksheet pivot data, unexpected error: {e}") def generate_pivot_table(self, data: dict[str, Any]) -> str: - columns = data['metadata']['columns'] - rows = data['metadata']['rows'] - values = data['metadata']['values'] + columns = data["metadata"]["columns"] + rows = data["metadata"]["rows"] + values = data["metadata"]["values"] - rows_data = data['data'] + rows_data = data["data"] - header = ([row['displayName'] for row in rows] if rows else []) + [column['displayName'] for column in columns] + [value['displayName'] for value in values] - line = (['---'] * len(rows) if rows else []) + ['---'] * len(columns) + ['--:'] * len(values) + header = ( + ([row["displayName"] for row in rows] if rows else []) + + [column["displayName"] for column in columns] + + [value["displayName"] for value in values] + ) + line = (["---"] * len(rows) if rows else []) + ["---"] * len(columns) + ["--:"] * len(values) table = [header, line] for row in rows_data: - row_data = [self.replace_pipe(row['rows'][r['controlId']]) for r in rows] if rows else [] - row_data.extend([self.replace_pipe(row['columns'][column['controlId']]) for column in columns]) - row_data.extend([self.replace_pipe(str(row['values'][value['controlId']])) for value in values]) + row_data = [self.replace_pipe(row["rows"][r["controlId"]]) for r in rows] if rows else [] + row_data.extend([self.replace_pipe(row["columns"][column["controlId"]]) for column in columns]) + row_data.extend([self.replace_pipe(str(row["values"][value["controlId"]])) for value in values]) table.append(row_data) - return '\n'.join([('|'+'|'.join(row) +'|') for row in table]) - + return "\n".join([("|" + "|".join(row) + "|") for row in table]) + def replace_pipe(self, text: str) -> str: - return text.replace('|', '▏').replace('\n', ' ') - + return text.replace("|", "▏").replace("\n", " ") + def generate_pivot_json(self, data: dict[str, Any]) -> dict: fields = { "x-axis": [ @@ -103,13 +109,14 @@ def generate_pivot_json(self, data: dict[str, Any]) -> dict: for column in data["metadata"]["columns"] ], "y-axis": [ - {"fieldId": row["controlId"], "fieldName": row["displayName"]} - for row in data["metadata"]["rows"] - ] if data["metadata"]["rows"] else [], + {"fieldId": row["controlId"], "fieldName": row["displayName"]} for row in data["metadata"]["rows"] + ] + if data["metadata"]["rows"] + else [], "values": [ {"fieldId": value["controlId"], "fieldName": value["displayName"]} for value in data["metadata"]["values"] - ] + ], } # fields = ([ # {"fieldId": row["controlId"], "fieldName": row["displayName"]} @@ -127,4 +134,4 @@ def generate_pivot_json(self, data: dict[str, Any]) -> dict: row_data.update(row["columns"]) row_data.update(row["values"]) rows.append(row_data) - return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} \ No newline at end of file + return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index dddc041cc19fb3..7e9f70f8e5feb3 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -9,152 +9,173 @@ class ListWorksheetRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') + return self.create_text_message("Invalid parameter App Key") - sign = tool_parameters.get('sign', '') + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') + return self.create_text_message("Invalid parameter Sign") - worksheet_id = tool_parameters.get('worksheet_id', '') + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') + return self.create_text_message("Invalid parameter Worksheet ID") - host = tool_parameters.get('host', '') + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" - + url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} - field_ids = tool_parameters.get('field_ids', '') + field_ids = tool_parameters.get("field_ids", "") try: res = httpx.post(url_fields, headers=headers, json=payload, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the worksheet information. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to get the worksheet information. {}".format(res_json["error_msg"]) + ) else: - worksheet_name = res_json['data']['name'] - fields, schema, table_header = self.get_schema(res_json['data']['controls'], field_ids) + worksheet_name = res_json["data"]["name"] + fields, schema, table_header = self.get_schema(res_json["data"]["controls"], field_ids) else: return self.create_text_message( - f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}") + f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to get the worksheet information, something went wrong: {}".format(e)) + return self.create_text_message( + "Failed to get the worksheet information, something went wrong: {}".format(e) + ) if field_ids: - payload['controls'] = [v.strip() for v in field_ids.split(',')] if field_ids else [] - filters = tool_parameters.get('filters', '') + payload["controls"] = [v.strip() for v in field_ids.split(",")] if field_ids else [] + filters = tool_parameters.get("filters", "") if filters: - payload['filters'] = json.loads(filters) - sort_id = tool_parameters.get('sort_id', '') - sort_is_asc = tool_parameters.get('sort_is_asc', False) + payload["filters"] = json.loads(filters) + sort_id = tool_parameters.get("sort_id", "") + sort_is_asc = tool_parameters.get("sort_is_asc", False) if sort_id: - payload['sortId'] = sort_id - payload['isAsc'] = sort_is_asc - limit = tool_parameters.get('limit', 50) - payload['pageSize'] = limit - page_index = tool_parameters.get('page_index', 1) - payload['pageIndex'] = page_index - payload['useControlId'] = True - payload['listType'] = 1 + payload["sortId"] = sort_id + payload["isAsc"] = sort_is_asc + limit = tool_parameters.get("limit", 50) + payload["pageSize"] = limit + page_index = tool_parameters.get("page_index", 1) + payload["pageIndex"] = page_index + payload["useControlId"] = True + payload["listType"] = 1 url = f"{host}/v2/open/worksheet/getFilterRows" try: res = httpx.post(url, headers=headers, json=payload, timeout=90) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the records. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message("Failed to get the records. {}".format(res_json["error_msg"])) else: result = { "fields": fields, "rows": [], "total": res_json.get("data", {}).get("total"), - "payload": {key: payload[key] for key in ['worksheetId', 'controls', 'filters', 'sortId', 'isAsc', 'pageSize', 'pageIndex'] if key in payload} + "payload": { + key: payload[key] + for key in [ + "worksheetId", + "controls", + "filters", + "sortId", + "isAsc", + "pageSize", + "pageIndex", + ] + if key in payload + }, } rows = res_json.get("data", {}).get("rows", []) - result_type = tool_parameters.get('result_type', '') - if not result_type: result_type = 'table' - if result_type == 'json': + result_type = tool_parameters.get("result_type", "") + if not result_type: + result_type = "table" + if result_type == "json": for row in rows: - result['rows'].append(self.get_row_field_value(row, schema)) + result["rows"].append(self.get_row_field_value(row, schema)) return self.create_text_message(json.dumps(result, ensure_ascii=False)) else: result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." - if result['total'] > 0: + if result["total"] > 0: result_text += f" The following are {result['total'] if result['total'] < limit else limit} pieces of data presented in a table format:\n\n{table_header}" for row in rows: result_values = [] for f in fields: - result_values.append(self.handle_value_type(row[f['fieldId']], schema[f['fieldId']])) - result_text += '\n|'+'|'.join(result_values)+'|' + result_values.append( + self.handle_value_type(row[f["fieldId"]], schema[f["fieldId"]]) + ) + result_text += "\n|" + "|".join(result_values) + "|" return self.create_text_message(result_text) else: return self.create_text_message( - f"Failed to get the records, status code: {res.status_code}, response: {res.text}") + f"Failed to get the records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get the records, something went wrong: {}".format(e)) - def get_row_field_value(self, row: dict, schema: dict): row_value = {"rowid": row["rowid"]} for field in schema: row_value[field] = self.handle_value_type(row[field], schema[field]) return row_value - - def get_schema(self, controls: list, fieldids: str): - allow_fields = {v.strip() for v in fieldids.split(',')} if fieldids else set() + def get_schema(self, controls: list, fieldids: str): + allow_fields = {v.strip() for v in fieldids.split(",")} if fieldids else set() fields = [] schema = {} field_names = [] for control in controls: control_type_id = self.get_real_type_id(control) - if (control_type_id in self._get_ignore_types()) or (allow_fields and not control['controlId'] in allow_fields): + if (control_type_id in self._get_ignore_types()) or ( + allow_fields and not control["controlId"] in allow_fields + ): continue else: - fields.append({'fieldId': control['controlId'], 'fieldName': control['controlName']}) - schema[control['controlId']] = {'typeId': control_type_id, 'options': self.set_option(control)} - field_names.append(control['controlName']) - if (not allow_fields or ('ctime' in allow_fields)): - fields.append({'fieldId': 'ctime', 'fieldName': 'Created Time'}) - schema['ctime'] = {'typeId': 16, 'options': {}} + fields.append({"fieldId": control["controlId"], "fieldName": control["controlName"]}) + schema[control["controlId"]] = {"typeId": control_type_id, "options": self.set_option(control)} + field_names.append(control["controlName"]) + if not allow_fields or ("ctime" in allow_fields): + fields.append({"fieldId": "ctime", "fieldName": "Created Time"}) + schema["ctime"] = {"typeId": 16, "options": {}} field_names.append("Created Time") - fields.append({'fieldId':'rowid', 'fieldName': 'Record Row ID'}) - schema['rowid'] = {'typeId': 2, 'options': {}} + fields.append({"fieldId": "rowid", "fieldName": "Record Row ID"}) + schema["rowid"] = {"typeId": 2, "options": {}} field_names.append("Record Row ID") - return fields, schema, '|'+'|'.join(field_names)+'|\n|'+'---|'*len(field_names) - + return fields, schema, "|" + "|".join(field_names) + "|\n|" + "---|" * len(field_names) + def get_real_type_id(self, control: dict) -> int: - return control['sourceControlType'] if control['type'] == 30 else control['type'] - + return control["sourceControlType"] if control["type"] == 30 else control["type"] + def set_option(self, control: dict) -> dict: options = {} - if control.get('options'): - options = {option['key']: option['value'] for option in control['options']} - elif control.get('advancedSetting', {}).get('itemnames'): + if control.get("options"): + options = {option["key"]: option["value"] for option in control["options"]} + elif control.get("advancedSetting", {}).get("itemnames"): try: - itemnames = json.loads(control['advancedSetting']['itemnames']) - options = {item['key']: item['value'] for item in itemnames} + itemnames = json.loads(control["advancedSetting"]["itemnames"]) + options = {item["key"]: item["value"] for item in itemnames} except json.JSONDecodeError: pass return options def _get_ignore_types(self): return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} - + def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: @@ -167,33 +188,33 @@ def handle_value_type(self, value, field): value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) - return self.rich_text_to_plain_text(value) if value else '' + return self.rich_text_to_plain_text(value) if value else "" def process_value(self, value): if isinstance(value, str): - if value.startswith("[{\"accountId\""): + if value.startswith('[{"accountId"'): value = json.loads(value) - value = ', '.join([item['fullname'] for item in value]) - elif value.startswith("[{\"departmentId\""): + value = ", ".join([item["fullname"] for item in value]) + elif value.startswith('[{"departmentId"'): value = json.loads(value) - value = '、'.join([item['departmentName'] for item in value]) - elif value.startswith("[{\"organizeId\""): + value = "、".join([item["departmentName"] for item in value]) + elif value.startswith('[{"organizeId"'): value = json.loads(value) - value = '、'.join([item['organizeName'] for item in value]) - elif value.startswith("[{\"file_id\""): - value = '' - elif value == '[]': - value = '' - elif hasattr(value, 'accountId'): - value = value['fullname'] + value = "、".join([item["organizeName"] for item in value]) + elif value.startswith('[{"file_id"'): + value = "" + elif value == "[]": + value = "" + elif hasattr(value, "accountId"): + value = value["fullname"] return value def parse_cascade_or_associated(self, field, value): - if (field['typeId'] == 35 and value.startswith('[')) or (field['typeId'] == 29 and value.startswith('[{')): + if (field["typeId"] == 35 and value.startswith("[")) or (field["typeId"] == 29 and value.startswith("[{")): value = json.loads(value) - value = value[0]['name'] if len(value) > 0 else '' + value = value[0]["name"] if len(value) > 0 else "" else: - value = '' + value = "" return value def parse_location(self, value): @@ -205,5 +226,5 @@ def parse_location(self, value): return value def rich_text_to_plain_text(self, rich_text): - text = re.sub(r'<[^>]+>', '', rich_text) if '<' in rich_text else rich_text - return text.replace("|", "▏").replace("\n", " ") \ No newline at end of file + text = re.sub(r"<[^>]+>", "", rich_text) if "<" in rich_text else rich_text + return text.replace("|", "▏").replace("\n", " ") diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 960cbd10acadb1..b4193f00bfa515 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -8,75 +8,76 @@ class ListWorksheetsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Sign") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v1/open/app/get" - result_type = tool_parameters.get('result_type', '') + result_type = tool_parameters.get("result_type", "") if not result_type: - result_type = 'table' + result_type = "table" - headers = { 'Content-Type': 'application/json' } - params = { "appKey": appkey, "sign": sign, } + headers = {"Content-Type": "application/json"} + params = { + "appKey": appkey, + "sign": sign, + } try: res = httpx.get(url, headers=headers, params=params, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to access the application. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to access the application. {}".format(res_json["error_msg"]) + ) else: - if result_type == 'json': + if result_type == "json": worksheets = [] - for section in res_json['data']['sections']: + for section in res_json["data"]["sections"]: worksheets.extend(self._extract_worksheets(section, result_type)) return self.create_text_message(text=json.dumps(worksheets, ensure_ascii=False)) else: - worksheets = '|worksheetId|worksheetName|description|\n|---|---|---|' - for section in res_json['data']['sections']: + worksheets = "|worksheetId|worksheetName|description|\n|---|---|---|" + for section in res_json["data"]["sections"]: worksheets += self._extract_worksheets(section, result_type) return self.create_text_message(worksheets) else: return self.create_text_message( - f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}") + f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list worksheets, something went wrong: {}".format(e)) def _extract_worksheets(self, section, type): items = [] - tables = '' - for item in section.get('items', []): - if item.get('type') == 0 and (not 'notes' in item or item.get('notes') != 'NO'): - if type == 'json': - filtered_item = { - 'id': item['id'], - 'name': item['name'], - 'notes': item.get('notes', '') - } + tables = "" + for item in section.get("items", []): + if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"): + if type == "json": + filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) else: tables += f"\n|{item['id']}|{item['name']}|{item.get('notes', '')}|" - for child_section in section.get('childSections', []): - if type == 'json': - items.extend(self._extract_worksheets(child_section, 'json')) + for child_section in section.get("childSections", []): + if type == "json": + items.extend(self._extract_worksheets(child_section, "json")) else: - tables += self._extract_worksheets(child_section, 'table') - - return items if type == 'json' else tables \ No newline at end of file + tables += self._extract_worksheets(child_section, "table") + + return items if type == "json" else tables diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 6ca1b98d90b3f1..32abb18f9a796b 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -8,44 +8,43 @@ class UpdateWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Record Row ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/editRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to update the record. {res_json['error_msg']}") return self.create_text_message("Record updated successfully.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index 12e5058cdc92f0..154e15db016dd1 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -10,27 +10,29 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if credentials['api_key'] is None: - credentials['api_key'] = '' + if credentials["api_key"] is None: + credentials["api_key"] = "" else: - result = JinaReaderTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "url": "https://example.com", - }, - )[0] + result = ( + JinaReaderTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "url": "https://example.com", + }, + )[0] + ) message = json.loads(result.message) - if message['code'] != 200: - raise ToolProviderCredentialValidationError(message['message']) + if message["code"] != 200: + raise ToolProviderCredentialValidationError(message["message"]) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - + def _get_tool_labels(self) -> list[ToolLabelEnum]: - return [ - ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY - ] \ No newline at end of file + return [ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY] diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index cee46cee2390e1..0dd55c65291783 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -9,26 +9,25 @@ class JinaReaderTool(BuiltinTool): - _jina_reader_endpoint = 'https://r.jina.ai/' + _jina_reader_endpoint = "https://r.jina.ai/" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - url = tool_parameters['url'] + url = tool_parameters["url"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - request_params = tool_parameters.get('request_params') - if request_params is not None and request_params != '': + request_params = tool_parameters.get("request_params") + if request_params is not None and request_params != "": try: request_params = json.loads(request_params) if not isinstance(request_params, dict): @@ -36,40 +35,40 @@ def _invoke(self, except (json.JSONDecodeError, ValueError) as e: raise ValueError(f"Invalid request_params: {e}") - target_selector = tool_parameters.get('target_selector') - if target_selector is not None and target_selector != '': - headers['X-Target-Selector'] = target_selector + target_selector = tool_parameters.get("target_selector") + if target_selector is not None and target_selector != "": + headers["X-Target-Selector"] = target_selector - wait_for_selector = tool_parameters.get('wait_for_selector') - if wait_for_selector is not None and wait_for_selector != '': - headers['X-Wait-For-Selector'] = wait_for_selector + wait_for_selector = tool_parameters.get("wait_for_selector") + if wait_for_selector is not None and wait_for_selector != "": + headers["X-Wait-For-Selector"] = wait_for_selector - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), - max_retries=max_retries + max_retries=max_retries, ) - if tool_parameters.get('summary', False): + if tool_parameters.get("summary", False): return self.create_text_message(self.summary(user_id, response.text)) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index d4a81cd0965142..30af6de7831e59 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -8,44 +8,39 @@ class JinaSearchTool(BuiltinTool): - _jina_search_endpoint = 'https://s.jina.ai/' + _jina_search_endpoint = "https://s.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters['query'] + query = tool_parameters["query"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( - str(URL(self._jina_search_endpoint + query)), - headers=headers, - timeout=(10, 60), - max_retries=max_retries + str(URL(self._jina_search_endpoint + query)), headers=headers, timeout=(10, 60), max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py index 0d018e3ca27af5..06dabcc9c2a74e 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -6,33 +6,29 @@ class JinaTokenizerTool(BuiltinTool): - _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/' + _jina_tokenizer_endpoint = "https://tokenize.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> ToolInvokeMessage: - content = tool_parameters['content'] - body = { - "content": content - } - - headers = { - 'Content-Type': 'application/json' - } - - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') - - if tool_parameters.get('return_chunks', False): - body['return_chunks'] = True - - if tool_parameters.get('return_tokens', False): - body['return_tokens'] = True - - if tokenizer := tool_parameters.get('tokenizer'): - body['tokenizer'] = tokenizer + content = tool_parameters["content"] + body = {"content": content} + + headers = {"Content-Type": "application/json"} + + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") + + if tool_parameters.get("return_chunks", False): + body["return_chunks"] = True + + if tool_parameters.get("return_tokens", False): + body["return_tokens"] = True + + if tokenizer := tool_parameters.get("tokenizer"): + body["tokenizer"] = tokenizer response = ssrf_proxy.post( self._jina_tokenizer_endpoint, diff --git a/api/core/tools/provider/builtin/json_process/json_process.py b/api/core/tools/provider/builtin/json_process/json_process.py index f6eed3c6282314..10746210b5c652 100644 --- a/api/core/tools/provider/builtin/json_process/json_process.py +++ b/api/core/tools/provider/builtin/json_process/json_process.py @@ -8,10 +8,9 @@ class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - JSONParseTool().invoke(user_id='', - tool_parameters={ - 'content': '{"name": "John", "age": 30, "city": "New York"}', - 'json_filter': '$.name' - }) + JSONParseTool().invoke( + user_id="", + tool_parameters={"content": '{"name": "John", "age": 30, "city": "New York"}', "json_filter": "$.name"}, + ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index 1b49cfe2f300f8..fcab3d71a93cf9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -8,34 +8,35 @@ class JSONDeleteTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the JSON delete tool """ # Get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # Get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to delete JSON content: {str(e)}') + return self.create_text_message(f"Failed to delete JSON content: {str(e)}") def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) - expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $ + expr = parse("$." + query.lstrip("$.")) # Ensure query path starts with $ matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 48d1bdcab48885..793c74e5f9df51 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -8,46 +8,49 @@ class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get new value - new_value = tool_parameters.get('new_value', '') + new_value = tool_parameters.get("new_value", "") if not new_value: - return self.create_text_message('Invalid parameter new_value') + return self.create_text_message("Invalid parameter new_value") # get insert position - index = tool_parameters.get('index') + index = tool_parameters.get("index") # get create path - create_path = tool_parameters.get('create_path', False) + create_path = tool_parameters.get("create_path", False) # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to insert JSON content') + return self.create_text_message("Failed to insert JSON content") - def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False): + def _insert( + self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False + ): try: input_data = json.loads(origin_json) expr = parse(query) @@ -61,13 +64,13 @@ def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decod if not matches and create_path: # create new path - path_parts = query.strip('$').strip('.').split('.') + path_parts = query.strip("$").strip(".").split(".") current = input_data for i, part in enumerate(path_parts): - if '[' in part and ']' in part: + if "[" in part and "]" in part: # process array index - array_name, index = part.split('[') - index = int(index.rstrip(']')) + array_name, index = part.split("[") + index = int(index.rstrip("]")) if array_name not in current: current[array_name] = [] while len(current[array_name]) <= index: diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index ecd39113ae5498..37cae401533190 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -8,29 +8,30 @@ class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get json filter - json_filter = tool_parameters.get('json_filter', '') + json_filter = tool_parameters.get("json_filter", "") if not json_filter: - return self.create_text_message('Invalid parameter json_filter') + return self.create_text_message("Invalid parameter json_filter") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._extract(content, json_filter, ensure_ascii) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to extract JSON content') + return self.create_text_message("Failed to extract JSON content") # Extract data from JSON content def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str: diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index b19198aa938942..383825c2d0b259 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -8,55 +8,60 @@ class JSONReplaceTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get replace value - replace_value = tool_parameters.get('replace_value', '') + replace_value = tool_parameters.get("replace_value", "") if not replace_value: - return self.create_text_message('Invalid parameter replace_value') + return self.create_text_message("Invalid parameter replace_value") # get replace model - replace_model = tool_parameters.get('replace_model', '') + replace_model = tool_parameters.get("replace_model", "") if not replace_model: - return self.create_text_message('Invalid parameter replace_model') + return self.create_text_message("Invalid parameter replace_model") # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: - if replace_model == 'pattern': + if replace_model == "pattern": # get replace pattern - replace_pattern = tool_parameters.get('replace_pattern', '') + replace_pattern = tool_parameters.get("replace_pattern", "") if not replace_pattern: - return self.create_text_message('Invalid parameter replace_pattern') - result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode) - elif replace_model == 'key': + return self.create_text_message("Invalid parameter replace_pattern") + result = self._replace_pattern( + content, query, replace_pattern, replace_value, ensure_ascii, value_decode + ) + elif replace_model == "key": result = self._replace_key(content, query, replace_value, ensure_ascii) - elif replace_model == 'value': + elif replace_model == "value": result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to replace JSON content') + return self.create_text_message("Failed to replace JSON content") # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_pattern( + self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) @@ -102,7 +107,9 @@ def _replace_key(self, content: str, query: str, replace_value: str, ensure_asci return str(e) # Replace value - def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_value( + self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index bac6576797f067..50db74dd9ebced 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -13,7 +13,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "source_code": "print('hello world')", "language_id": 71, @@ -21,4 +21,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py index 6031687c03f48b..b8d654ff639575 100644 --- a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py +++ b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py @@ -9,11 +9,13 @@ class ExecuteCodeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ - api_key = self.runtime.credentials['X-RapidAPI-Key'] + api_key = self.runtime.credentials["X-RapidAPI-Key"] url = "https://judge0-ce.p.rapidapi.com/submissions" @@ -22,15 +24,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn headers = { "Content-Type": "application/json", "X-RapidAPI-Key": api_key, - "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com" + "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com", } payload = { - "language_id": tool_parameters['language_id'], - "source_code": tool_parameters['source_code'], - "stdin": tool_parameters.get('stdin', ''), - "expected_output": tool_parameters.get('expected_output', ''), - "additional_files": tool_parameters.get('additional_files', ''), + "language_id": tool_parameters["language_id"], + "source_code": tool_parameters["source_code"], + "stdin": tool_parameters.get("stdin", ""), + "expected_output": tool_parameters.get("expected_output", ""), + "additional_files": tool_parameters.get("additional_files", ""), } response = post(url, data=json.dumps(payload), headers=headers, params=querystring) @@ -38,22 +40,22 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code != 201: raise Exception(response.text) - token = response.json()['token'] + token = response.json()["token"] url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}" - headers = { - "X-RapidAPI-Key": api_key - } - + headers = {"X-RapidAPI-Key": api_key} + response = requests.get(url, headers=headers) if response.status_code == 200: result = response.json() - return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n" - f"stderr: {result.get('stderr', '')}\n" - f"compile_output: {result.get('compile_output', '')}\n" - f"message: {result.get('message', '')}\n" - f"status: {result['status']['description']}\n" - f"time: {result.get('time', '')} seconds\n" - f"memory: {result.get('memory', '')} bytes") + return self.create_text_message( + text=f"stdout: {result.get('stdout', '')}\n" + f"stderr: {result.get('stderr', '')}\n" + f"compile_output: {result.get('compile_output', '')}\n" + f"message: {result.get('message', '')}\n" + f"status: {result['status']['description']}\n" + f"time: {result.get('time', '')} seconds\n" + f"memory: {result.get('memory', '')} bytes" + ) else: - return self.create_text_message(text=f"Error retrieving submission details: {response.text}") \ No newline at end of file + return self.create_text_message(text=f"Error retrieving submission details: {response.text}") diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py index 7226a5c1686feb..d4b449ec87a18a 100644 --- a/api/core/tools/provider/builtin/maths/maths.py +++ b/api/core/tools/provider/builtin/maths/maths.py @@ -9,9 +9,9 @@ class MathsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: EvaluateExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'expression': '1+(2+3)*4', + "expression": "1+(2+3)*4", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index bf73ed69181eaa..0c5b5e41cbe1e1 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -8,22 +8,23 @@ class EvaluateExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - expression = tool_parameters.get('expression', '').strip() + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = ne.evaluate(expression) result_str = str(result) except Exception as e: - logging.exception(f'Error evaluating expression: {expression}') - return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}') - return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') \ No newline at end of file + logging.exception(f"Error evaluating expression: {expression}") + return self.create_text_message(f"Invalid expression: {expression}, error: {str(e)}") + return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') diff --git a/api/core/tools/provider/builtin/nominatim/nominatim.py b/api/core/tools/provider/builtin/nominatim/nominatim.py index b6f29b5feb4e44..5a24bed7507eb6 100644 --- a/api/core/tools/provider/builtin/nominatim/nominatim.py +++ b/api/core/tools/provider/builtin/nominatim/nominatim.py @@ -8,16 +8,20 @@ class NominatimProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NominatimSearchTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'query': 'London', - 'limit': 1, - }, + result = ( + NominatimSearchTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "query": "London", + "limit": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py index e21ce14f542161..ffa8ad0fcc02e0 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py @@ -8,40 +8,33 @@ class NominatimLookupTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - osm_ids = tool_parameters.get('osm_ids', '') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + osm_ids = tool_parameters.get("osm_ids", "") + if not osm_ids: - return self.create_text_message('Please provide OSM IDs') + return self.create_text_message("Please provide OSM IDs") + + params = {"osm_ids": osm_ids, "format": "json", "addressdetails": 1} - params = { - 'osm_ids': osm_ids, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'lookup', params) + return self._make_request(user_id, "lookup", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py index 438d5219e97887..f46691e1a3ebb4 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py @@ -8,42 +8,34 @@ class NominatimReverseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - lat = tool_parameters.get('lat') - lon = tool_parameters.get('lon') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + lat = tool_parameters.get("lat") + lon = tool_parameters.get("lon") + if lat is None or lon is None: - return self.create_text_message('Please provide both latitude and longitude') + return self.create_text_message("Please provide both latitude and longitude") + + params = {"lat": lat, "lon": lon, "format": "json", "addressdetails": 1} - params = { - 'lat': lat, - 'lon': lon, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'reverse', params) + return self._make_request(user_id, "reverse", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py index 983cbc0e346577..34851d86dcaa5f 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py @@ -8,42 +8,34 @@ class NominatimSearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query', '') - limit = tool_parameters.get('limit', 10) - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query", "") + limit = tool_parameters.get("limit", 10) + if not query: - return self.create_text_message('Please input a search query') + return self.create_text_message("Please input a search query") + + params = {"q": query, "format": "json", "limit": limit, "addressdetails": 1} - params = { - 'q': query, - 'format': 'json', - 'limit': limit, - 'addressdetails': 1 - } - - return self._make_request(user_id, 'search', params) + return self._make_request(user_id, "search", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index b753be47917559..762e158459cc2a 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -12,10 +12,10 @@ def _extract_loras(self, loras_str: str): if not loras_str: return [] - loras_ori_list = lora_str.strip().split(';') + loras_ori_list = lora_str.strip().split(";") result_list = [] for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') + lora_info = lora_str.strip().split(",") lora = Txt2ImgV3LoRA( model_name=lora_info[0].strip(), strength=float(lora_info[1]), @@ -28,43 +28,39 @@ def _extract_embeddings(self, embeddings_str: str): if not embeddings_str: return [] - embeddings_ori_list = embeddings_str.strip().split(';') + embeddings_ori_list = embeddings_str.strip().split(";") result_list = [] for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) + embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip()) result_list.append(embedding) return result_list def _extract_hires_fix(self, hires_fix_str: str): - hires_fix_info = hires_fix_str.strip().split(',') - if 'upscaler' in hires_fix_info: + hires_fix_info = hires_fix_str.strip().split(",") + if "upscaler" in hires_fix_info: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() + upscaler=hires_fix_info[3].strip(), ) else: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) + strength=float(hires_fix_info[2]), ) return hires_fix def _extract_refiner(self, switch_at: str): - refiner = Txt2ImgV3Refiner( - switch_at=float(switch_at) - ) + refiner = Txt2ImgV3Refiner(switch_at=float(switch_at)) return refiner def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: """ - is hit nsfw + is hit nsfw """ if image.nsfw_detection_result is None: return False diff --git a/api/core/tools/provider/builtin/novitaai/novitaai.py b/api/core/tools/provider/builtin/novitaai/novitaai.py index 1e7d9757c3d81e..d5e32eff29373a 100644 --- a/api/core/tools/provider/builtin/novitaai/novitaai.py +++ b/api/core/tools/provider/builtin/novitaai/novitaai.py @@ -8,23 +8,27 @@ class NovitaAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NovitaAiTxt2ImgTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors', - 'prompt': 'a futuristic city with flying cars', - 'negative_prompt': '', - 'width': 128, - 'height': 128, - 'image_num': 1, - 'guidance_scale': 7.5, - 'seed': -1, - 'steps': 1, - }, + result = ( + NovitaAiTxt2ImgTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors", + "prompt": "a futuristic city with flying cars", + "negative_prompt": "", + "width": 128, + "height": 128, + "image_num": 1, + "guidance_scale": 7.5, + "seed": -1, + "steps": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index e63c8919575620..f76587bea1b96e 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -12,17 +12,18 @@ class NovitaAiCreateTileTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -30,21 +31,23 @@ def _invoke(self, results = [] results.append( - self.create_blob_message(blob=b64decode(client_result.image_file), - meta={'mime_type': f'image/{client_result.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(client_result.image_file), + meta={"mime_type": f"image/{client_result.image_type}"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index ec2927675e14b0..fe105f70a7262d 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -12,127 +12,137 @@ class NovitaAiModelQueryTool(BuiltinTool): - _model_query_endpoint = 'https://api.novita.ai/v3/model' + _model_query_endpoint = "https://api.novita.ai/v3/model" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer " + api_key - } + api_key = self.runtime.credentials.get("api_key") + headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key} params = self._process_parameters(tool_parameters) - result_type = params.get('result_type') - del params['result_type'] + result_type = params.get("result_type") + del params["result_type"] models_data = self._query_models( models_data=[], headers=headers, params=params, - recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True + recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True, ) - result_str = '' - if result_type == 'first sd_name': - result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else '' - elif result_type == 'first name sd_name pair': - result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else '' - elif result_type == 'sd_name array': - sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else [] + result_str = "" + if result_type == "first sd_name": + result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else "" + elif result_type == "first name sd_name pair": + result_str = ( + json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]}) + if len(models_data) > 0 + else "" + ) + elif result_type == "sd_name array": + sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(sd_name_array) - elif result_type == 'name array': - name_array = [model['name'] for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name array": + name_array = [model["name"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(name_array) - elif result_type == 'name sd_name pair array': - name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} - for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name sd_name pair array": + name_sd_name_pair_array = ( + [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data] + if len(models_data) > 0 + else [] + ) result_str = json.dumps(name_sd_name_pair_array) - elif result_type == 'whole info array': + elif result_type == "whole info array": result_str = json.dumps(models_data) else: raise NotImplementedError return self.create_text_message(result_str) - def _query_models(self, models_data: list, headers: dict[str, Any], - params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: + def _query_models( + self, + models_data: list, + headers: dict[str, Any], + params: dict[str, Any], + pagination_cursor: str = "", + recursive: bool = True, + ) -> list: """ - query models + query models """ inside_params = deepcopy(params) - if pagination_cursor != '': - inside_params['pagination.cursor'] = pagination_cursor + if pagination_cursor != "": + inside_params["pagination.cursor"] = pagination_cursor response = ssrf_proxy.get( - url=str(URL(self._model_query_endpoint)), - headers=headers, - params=params, - timeout=(10, 60) + url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60) ) res_data = response.json() - models_data.extend(res_data['models']) + models_data.extend(res_data["models"]) - res_data_len = len(res_data['models']) - if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False: + res_data_len = len(res_data["models"]) + if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False: # deduplicate df = DataFrame.from_dict(models_data) - df_unique = df.drop_duplicates(subset=['id']) - models_data = df_unique.to_dict('records') + df_unique = df.drop_duplicates(subset=["id"]) + models_data = df_unique.to_dict("records") return models_data return self._query_models( models_data=models_data, headers=headers, params=inside_params, - pagination_cursor=res_data['pagination']['next_cursor'] + pagination_cursor=res_data["pagination"]["next_cursor"], ) def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ process_parameters = deepcopy(parameters) res_parameters = {} # delete none or empty - keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""] for k in keys_to_delete: del process_parameters[k] - if 'query' in process_parameters and process_parameters.get('query') != 'unspecified': - res_parameters['filter.query'] = process_parameters['query'] + if "query" in process_parameters and process_parameters.get("query") != "unspecified": + res_parameters["filter.query"] = process_parameters["query"] - if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified': - res_parameters['filter.visibility'] = process_parameters['visibility'] + if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified": + res_parameters["filter.visibility"] = process_parameters["visibility"] - if 'source' in process_parameters and process_parameters.get('source') != 'unspecified': - res_parameters['filter.source'] = process_parameters['source'] + if "source" in process_parameters and process_parameters.get("source") != "unspecified": + res_parameters["filter.source"] = process_parameters["source"] - if 'type' in process_parameters and process_parameters.get('type') != 'unspecified': - res_parameters['filter.types'] = process_parameters['type'] + if "type" in process_parameters and process_parameters.get("type") != "unspecified": + res_parameters["filter.types"] = process_parameters["type"] - if 'is_sdxl' in process_parameters: - if process_parameters['is_sdxl'] == 'true': - res_parameters['filter.is_sdxl'] = True - elif process_parameters['is_sdxl'] == 'false': - res_parameters['filter.is_sdxl'] = False + if "is_sdxl" in process_parameters: + if process_parameters["is_sdxl"] == "true": + res_parameters["filter.is_sdxl"] = True + elif process_parameters["is_sdxl"] == "false": + res_parameters["filter.is_sdxl"] = False - res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name') + res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name") - res_parameters['pagination.limit'] = 1 \ - if res_parameters.get('result_type') == 'first sd_name' \ - or res_parameters.get('result_type') == 'first name sd_name pair'\ + res_parameters["pagination.limit"] = ( + 1 + if res_parameters.get("result_type") == "first sd_name" + or res_parameters.get("result_type") == "first name sd_name pair" else 100 + ) return res_parameters diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 5fef3d2da71bb9..9632c163cf45da 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -13,17 +13,18 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -32,56 +33,58 @@ def _invoke(self, results = [] for image_encoded, image in zip(client_result.images_encoded, client_result.images): if self._is_hit_nsfw_detection(image, 0.8): - results = self.create_text_message(text='NSFW detected!') + results = self.create_text_message(text="NSFW detected!") break results.append( - self.create_blob_message(blob=b64decode(image_encoded), - meta={'mime_type': f'image/{image.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(image_encoded), + meta={"mime_type": f"image/{image.image_type}"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] - if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0: - del res_parameters['clip_skip'] + if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0: + del res_parameters["clip_skip"] - if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0: - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0: + del res_parameters["refiner_switch_at"] - if 'enabled_enterprise_plan' in res_parameters: - res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']} - del res_parameters['enabled_enterprise_plan'] + if "enabled_enterprise_plan" in res_parameters: + res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]} + del res_parameters["enabled_enterprise_plan"] - if 'nsfw_detection_level' in res_parameters: - res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level']) + if "nsfw_detection_level" in res_parameters: + res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"]) # process loras - if 'loras' in res_parameters: - res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) + if "loras" in res_parameters: + res_parameters["loras"] = self._extract_loras(res_parameters.get("loras")) # process embeddings - if 'embeddings' in res_parameters: - res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) + if "embeddings" in res_parameters: + res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings")) # process hires_fix - if 'hires_fix' in res_parameters: - res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) + if "hires_fix" in res_parameters: + res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix")) # process refiner - if 'refiner_switch_at' in res_parameters: - res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters: + res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at")) + del res_parameters["refiner_switch_at"] return res_parameters diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py index 42f321e919745e..b8e5ed24d6b43f 100644 --- a/api/core/tools/provider/builtin/onebot/onebot.py +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -5,8 +5,6 @@ class OneBotProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - if not credentials.get("ob11_http_url"): - raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') + raise ToolProviderCredentialValidationError("OneBot HTTP URL is required.") diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py index 2a1a9f86de5b44..9c95bbc2ae8d2d 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -11,54 +11,29 @@ class SendGroupMsg(BuiltinTool): """OneBot v11 Tool: Send Group Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_group_id = tool_parameters.get('group_id', '') - - message = tool_parameters.get('message', '') + send_group_id = tool_parameters.get("group_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_group_msg" resp = requests.post( url, - json={ - 'group_id': send_group_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"group_id": send_group_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send group message: {resp.text}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {resp.text}"}) - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send group message: {e}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py index 8ef4d72ab63b44..1174c7f07d002f 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -11,54 +11,29 @@ class SendPrivateMsg(BuiltinTool): """OneBot v11 Tool: Send Private Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_user_id = tool_parameters.get('user_id', '') - - message = tool_parameters.get('message', '') + send_user_id = tool_parameters.get("user_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_private_msg" resp = requests.post( url, - json={ - 'user_id': send_user_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"user_id": send_user_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send private message: {resp.text}' - } - ) - - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"error": f"Failed to send private message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send private message: {e}' - } - ) \ No newline at end of file + return self.create_json_message({"error": f"Failed to send private message: {e}"}) diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a370c9..9e40249aba6b40 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -5,7 +5,6 @@ def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): - url = "https://api.openweathermap.org/data/2.5/weather" params = {"q": city, "appid": api_key, "units": units, "lang": language} @@ -16,21 +15,15 @@ class OpenweatherProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: if "api_key" not in credentials or not credentials.get("api_key"): - raise ToolProviderCredentialValidationError( - "Open weather API key is required." - ) + raise ToolProviderCredentialValidationError("Open weather API key is required.") apikey = credentials.get("api_key") try: response = query_weather(api_key=apikey) if response.status_code == 200: pass else: - raise ToolProviderCredentialValidationError( - (response.json()).get("info") - ) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: - raise ToolProviderCredentialValidationError( - "Open weather API Key is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("Open weather API Key is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py index d6c49a230f9db3..ed4ec487fa984a 100644 --- a/api/core/tools/provider/builtin/openweather/tools/weather.py +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -17,10 +17,7 @@ def _invoke( city = tool_parameters.get("city", "") if not city: return self.create_text_message("Please tell me your city") - if ( - "api_key" not in self.runtime.credentials - or not self.runtime.credentials.get("api_key") - ): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("OpenWeather API key is required.") units = tool_parameters.get("units", "metric") @@ -39,12 +36,9 @@ def _invoke( response = requests.get(url, params=params) if response.status_code == 200: - data = response.json() return self.create_text_message( - self.summary( - user_id=user_id, content=json.dumps(data, ensure_ascii=False) - ) + self.summary(user_id=user_id, content=json.dumps(data, ensure_ascii=False)) ) else: error_message = { @@ -55,6 +49,4 @@ def _invoke( return json.dumps(error_message) except Exception as e: - return self.create_text_message( - "Openweather API Key is invalid. {}".format(e) - ) + return self.create_text_message("Openweather API Key is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py index ff91edf18df4c0..80518853fb4a4b 100644 --- a/api/core/tools/provider/builtin/perplexity/perplexity.py +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -11,34 +11,26 @@ class PerplexityProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: headers = { "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { "model": "llama-3.1-sonar-small-128k-online", "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello" - } + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, ], "max_tokens": 5, "temperature": 0.1, "top_p": 0.9, - "stream": False + "stream": False, } try: response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() except requests.RequestException as e: - raise ToolProviderCredentialValidationError( - f"Failed to validate Perplexity API key: {str(e)}" - ) + raise ToolProviderCredentialValidationError(f"Failed to validate Perplexity API key: {str(e)}") if response.status_code != 200: raise ToolProviderCredentialValidationError( diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py index 5b1a263f9b2218..5ed4b9ca993483 100644 --- a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -8,65 +8,60 @@ PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + class PerplexityAITool(BuiltinTool): def _parse_response(self, response: dict) -> dict: """Parse the response from Perplexity AI API""" - if 'choices' in response and len(response['choices']) > 0: - message = response['choices'][0]['message'] + if "choices" in response and len(response["choices"]) > 0: + message = response["choices"][0]["message"] return { - 'content': message.get('content', ''), - 'role': message.get('role', ''), - 'citations': response.get('citations', []) + "content": message.get("content", ""), + "role": message.get("role", ""), + "citations": response.get("citations", []), } else: - return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []} + return {"content": "Unable to get a valid response", "role": "assistant", "citations": []} - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: headers = { "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { - "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'), + "model": tool_parameters.get("model", "llama-3.1-sonar-small-128k-online"), "messages": [ - { - "role": "system", - "content": "Be precise and concise." - }, - { - "role": "user", - "content": tool_parameters['query'] - } + {"role": "system", "content": "Be precise and concise."}, + {"role": "user", "content": tool_parameters["query"]}, ], - "max_tokens": tool_parameters.get('max_tokens', 4096), - "temperature": tool_parameters.get('temperature', 0.7), - "top_p": tool_parameters.get('top_p', 1), - "top_k": tool_parameters.get('top_k', 5), - "presence_penalty": tool_parameters.get('presence_penalty', 0), - "frequency_penalty": tool_parameters.get('frequency_penalty', 1), - "stream": False + "max_tokens": tool_parameters.get("max_tokens", 4096), + "temperature": tool_parameters.get("temperature", 0.7), + "top_p": tool_parameters.get("top_p", 1), + "top_k": tool_parameters.get("top_k", 5), + "presence_penalty": tool_parameters.get("presence_penalty", 0), + "frequency_penalty": tool_parameters.get("frequency_penalty", 1), + "stream": False, } - - if 'search_recency_filter' in tool_parameters: - payload['search_recency_filter'] = tool_parameters['search_recency_filter'] - if 'return_citations' in tool_parameters: - payload['return_citations'] = tool_parameters['return_citations'] - if 'search_domain_filter' in tool_parameters: - if isinstance(tool_parameters['search_domain_filter'], str): - payload['search_domain_filter'] = [tool_parameters['search_domain_filter']] - elif isinstance(tool_parameters['search_domain_filter'], list): - payload['search_domain_filter'] = tool_parameters['search_domain_filter'] - + + if "search_recency_filter" in tool_parameters: + payload["search_recency_filter"] = tool_parameters["search_recency_filter"] + if "return_citations" in tool_parameters: + payload["return_citations"] = tool_parameters["return_citations"] + if "search_domain_filter" in tool_parameters: + if isinstance(tool_parameters["search_domain_filter"], str): + payload["search_domain_filter"] = [tool_parameters["search_domain_filter"]] + elif isinstance(tool_parameters["search_domain_filter"], list): + payload["search_domain_filter"] = tool_parameters["search_domain_filter"] response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) - + return [ self.create_json_message(valuable_res), - self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)) + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)), ] diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 05cd171b873327..ea3a477c30178d 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 58811d65e6fe9a..fedfdbd859b4f1 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -51,17 +51,12 @@ def run(self, query: str) -> str: try: # Retrieve the top-k results for the query docs = [ - f"Published: {result['pub_date']}\nTitle: {result['title']}\n" - f"Summary: {result['summary']}" + f"Published: {result['pub_date']}\nTitle: {result['title']}\n" f"Summary: {result['summary']}" for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) ] # Join the results and limit the character count - return ( - "\n\n".join(docs)[:self.doc_content_chars_max] - if docs - else "No good PubMed Result was found" - ) + return "\n\n".join(docs)[: self.doc_content_chars_max] if docs else "No good PubMed Result was found" except Exception as ex: return f"PubMed exception: {ex}" @@ -91,13 +86,7 @@ def load(self, query: str) -> list[dict]: return articles def retrieve_article(self, uid: str, webenv: str) -> dict: - url = ( - self.base_url_efetch - + "db=pubmed&retmode=xml&id=" - + uid - + "&webenv=" - + webenv - ) + url = self.base_url_efetch + "db=pubmed&retmode=xml&id=" + uid + "&webenv=" + webenv retry = 0 while True: @@ -108,10 +97,7 @@ def retrieve_article(self, uid: str, webenv: str) -> dict: if e.code == 429 and retry < self.max_retry: # Too Many Requests error # wait for an exponentially increasing amount of time - print( - f"Too Many Requests, " - f"waiting for {self.sleep_time:.2f} seconds..." - ) + print(f"Too Many Requests, " f"waiting for {self.sleep_time:.2f} seconds...") time.sleep(self.sleep_time) self.sleep_time *= 2 retry += 1 @@ -125,27 +111,21 @@ def retrieve_article(self, uid: str, webenv: str) -> dict: if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - title = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + title = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get abstract abstract = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - abstract = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + abstract = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get publication date pub_date = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - pub_date = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + pub_date = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Return article as dictionary article = { @@ -182,6 +162,7 @@ def _run( class PubMedInput(BaseModel): query: str = Field(..., description="Search query.") + class PubMedSearchTool(BuiltinTool): """ Tool for performing a search using PubMed search engine. @@ -198,14 +179,13 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = PubmedQueryRun(args_schema=PubMedInput) result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py index 9fa7d012657fd8..8466b9a26b42b6 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.py +++ b/api/core/tools/provider/builtin/qrcode/qrcode.py @@ -8,9 +8,6 @@ class QRCodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - QRCodeGeneratorTool().invoke(user_id='', - tool_parameters={ - 'content': 'Dify 123 😊' - }) + QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 5eede98f5eed6c..8aefc651313598 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -13,43 +13,44 @@ class QRCodeGeneratorTool(BuiltinTool): error_correction_levels: dict[str, int] = { - 'L': ERROR_CORRECT_L, # <=7% - 'M': ERROR_CORRECT_M, # <=15% - 'Q': ERROR_CORRECT_Q, # <=25% - 'H': ERROR_CORRECT_H, # <=30% + "L": ERROR_CORRECT_L, # <=7% + "M": ERROR_CORRECT_M, # <=15% + "Q": ERROR_CORRECT_Q, # <=25% + "H": ERROR_CORRECT_H, # <=30% } - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get text content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get border size - border = tool_parameters.get('border', 0) + border = tool_parameters.get("border", 0) if border < 0 or border > 100: - return self.create_text_message('Invalid parameter border') + return self.create_text_message("Invalid parameter border") # get error_correction - error_correction = tool_parameters.get('error_correction', '') + error_correction = tool_parameters.get("error_correction", "") if error_correction not in self.error_correction_levels.keys(): - return self.create_text_message('Invalid parameter error_correction') + return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) - return self.create_blob_message(blob=image_bytes, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_blob_message( + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) except Exception: - logging.exception(f'Failed to generate QR code for content: {content}') - return self.create_text_message('Failed to generate QR code') + logging.exception(f"Failed to generate QR code for content: {content}") + return self.create_text_message("Failed to generate QR code") def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: qr = QRCode( diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py index d38ae1b292675f..c498105979f13e 100644 --- a/api/core/tools/provider/builtin/regex/regex.py +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -9,10 +9,10 @@ class RegexProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: RegexExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'content': '1+(2+3)*4', - 'expression': r'(\d+)', + "content": "1+(2+3)*4", + "expression": r"(\d+)", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py index 5d8f013d0d012c..786b4694040030 100644 --- a/api/core/tools/provider/builtin/regex/tools/regex_extract.py +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -6,22 +6,23 @@ class RegexExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - content = tool_parameters.get('content', '').strip() + content = tool_parameters.get("content", "").strip() if not content: - return self.create_text_message('Invalid content') - expression = tool_parameters.get('expression', '').strip() + return self.create_text_message("Invalid content") + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = re.findall(expression, content) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file + return self.create_text_message(f"Failed to extract result, error: {str(e)}") diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py index 6fa4f05acd7b9d..109bba8b2d8f79 100644 --- a/api/core/tools/provider/builtin/searchapi/searchapi.py +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -13,11 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearchApi dify", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "SearchApi dify", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index dd780aeadcf36c..d632304a46bba4 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -80,25 +81,29 @@ def _process_response(res: dict, type: str) -> str: toret = "No good search result found" return toret + class GoogleTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 81c67c51a9a7ae..1544061c08f652 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -50,7 +51,16 @@ def _process_response(res: dict, type: str) -> str: if type == "text": if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): for item in res["jobs"]: - toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n" + toret += ( + "title: " + + item["title"] + + "\n" + + "company_name: " + + item["company_name"] + + "content: " + + item["description"] + + "\n" + ) if toret == "": toret = "No good search result found" @@ -62,16 +72,18 @@ def _process_response(res: dict, type: str) -> str: toret = "No good search result found" return toret + class GoogleJobsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") @@ -80,9 +92,11 @@ def _invoke(self, ltype = 1 if is_remote else None - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 5d2657dddd1972..95a7aad7365f84 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -68,25 +69,29 @@ def _process_response(res: dict, type: str) -> str: toret = "No good search result found" return toret + class GoogleNewsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 6345b338011e7a..88def504fca128 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -55,18 +56,20 @@ def _process_response(res: dict) -> str: return toret + class YoutubeTranscriptsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - video_id = tool_parameters['video_id'] - language = tool_parameters.get('language', "en") + video_id = tool_parameters["video_id"] + language = tool_parameters.get("language", "en") - api_key = self.runtime.credentials['searchapi_api_key'] + api_key = self.runtime.credentials["searchapi_api_key"] result = SearchAPI(api_key).run(video_id, language=language) return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index ab354003e6f567..b7bbcc60b1ed26 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearXNG", - "limit": 1, - "search_type": "general" - }, + user_id="", + tool_parameters={"query": "SearXNG", "limit": 1, "search_type": "general"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index dc835a8e8cbd5b..c5e339a108e5b2 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -23,18 +23,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - host = self.runtime.credentials.get('searxng_base_url') + host = self.runtime.credentials.get("searxng_base_url") if not host: - raise Exception('SearXNG api is required') + raise Exception("SearXNG api is required") - response = requests.get(host, params={ - "q": tool_parameters.get('query'), - "format": "json", - "categories": tool_parameters.get('search_type', 'general') - }) + response = requests.get( + host, + params={ + "q": tool_parameters.get("query"), + "format": "json", + "categories": tool_parameters.get("search_type", "general"), + }, + ) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') + raise Exception(f"Error {response.status_code}: {response.text}") res = response.json().get("results", []) if not res: diff --git a/api/core/tools/provider/builtin/serper/serper.py b/api/core/tools/provider/builtin/serper/serper.py index 2a421093731477..cb1d090a9dd4b0 100644 --- a/api/core/tools/provider/builtin/serper/serper.py +++ b/api/core/tools/provider/builtin/serper/serper.py @@ -13,11 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/serper/tools/serper_search.py b/api/core/tools/provider/builtin/serper/tools/serper_search.py index 24facaf4ec3ae9..7baebbf95855e0 100644 --- a/api/core/tools/provider/builtin/serper/tools/serper_search.py +++ b/api/core/tools/provider/builtin/serper/tools/serper_search.py @@ -9,7 +9,6 @@ class SerperSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledgeGraph" in response: @@ -17,28 +16,19 @@ def _parse_response(self, response: dict) -> dict: result["description"] = response["knowledgeGraph"].get("description", "") if "organic" in response: result["organic"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - params = { - "q": tool_parameters['query'], - "gl": "us", - "hl": "en" - } - headers = { - 'X-API-KEY': self.runtime.credentials['serperapi_api_key'], - 'Content-Type': 'application/json' - } - response = requests.get(url=SERPER_API_URL, params=params,headers=headers) + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + params = {"q": tool_parameters["query"], "gl": "us", "hl": "en"} + headers = {"X-API-KEY": self.runtime.credentials["serperapi_api_key"], "Content-Type": "application/json"} + response = requests.get(url=SERPER_API_URL, params=params, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) return self.create_json_message(valuable_res) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py index 0df78280df3091..37a0b0755b1d39 100644 --- a/api/core/tools/provider/builtin/siliconflow/siliconflow.py +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -14,6 +14,4 @@ def _validate_credentials(self, credentials: dict) -> None: response = requests.get(url, headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError( - "SiliconFlow API key is invalid" - ) + raise ToolProviderCredentialValidationError("SiliconFlow API key is invalid") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index ed9f4be574703a..5fa99264841a34 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -5,17 +5,13 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -FLUX_URL = ( - "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" -) +FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" class FluxTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -36,9 +32,5 @@ def _invoke( res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index e8134a6565b281..e7c3c28d7bf53c 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -12,11 +12,9 @@ class StableDiffusionTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -43,9 +41,5 @@ def _invoke( res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py index f47557f2ef5852..85e0de76755898 100644 --- a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py +++ b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py @@ -7,25 +7,27 @@ class SlackWebhookTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Incoming Webhooks - API Document: https://api.slack.com/messaging/webhooks + Incoming Webhooks + API Document: https://api.slack.com/messaging/webhooks """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - webhook_url = tool_parameters.get('webhook_url', '') + webhook_url = tool_parameters.get("webhook_url", "") - if not webhook_url.startswith('https://hooks.slack.com/'): + if not webhook_url.startswith("https://hooks.slack.com/"): return self.create_text_message( - f'Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL') + f"Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL" + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { @@ -38,6 +40,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message was sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message through webhook. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message through webhook. {}".format(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f8e3e..e0b1a58a3f679a 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -29,12 +29,8 @@ def _validate_credentials(self, credentials: dict) -> None: # 0 success, pass else: - raise ToolProviderCredentialValidationError( - "image generate error, code:{}".format(code) - ) + raise ToolProviderCredentialValidationError("image generate error, code:{}".format(code)) except Exception as e: - raise ToolProviderCredentialValidationError( - "APPID APISecret APIKey is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("APPID APISecret APIKey is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index c7b0de014f9104..a6f5570af26be6 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -47,26 +47,25 @@ def parse_url(request_url): u = Url(host, path, schema) return u + def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): u = parse_url(request_url) host = u.host path = u.path now = datetime.now() date = format_date_time(mktime(now.timetuple())) - signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( - host, date, method, path - ) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path) signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( - encoding="utf-8" + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' ) + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") values = {"host": host, "date": date, "authorization": authorization} return request_url + "?" + urlencode(values) @@ -75,9 +74,7 @@ def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): def get_body(appid, text): body = { "header": {"app_id": appid, "uid": "123456789"}, - "parameter": { - "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} - }, + "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}}, "payload": {"message": {"text": [{"role": "user", "content": text}]}}, } return body @@ -85,13 +82,9 @@ def get_body(appid, text): def spark_response(text, appid, apikey, apisecret): host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" - url = assemble_ws_auth_url( - host, method="POST", api_key=apikey, api_secret=apisecret - ) + url = assemble_ws_auth_url(host, method="POST", api_key=apikey, api_secret=apisecret) content = get_body(appid, text) - response = requests.post( - url, json=content, headers={"content-type": "application/json"} - ).text + response = requests.post(url, json=content, headers={"content-type": "application/json"}).text return response @@ -105,19 +98,11 @@ def _invoke( invoke tools """ - if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( - "APPID" - ): + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get("APPID"): return self.create_text_message("APPID is required.") - if ( - "APISecret" not in self.runtime.credentials - or not self.runtime.credentials.get("APISecret") - ): + if "APISecret" not in self.runtime.credentials or not self.runtime.credentials.get("APISecret"): return self.create_text_message("APISecret is required.") - if ( - "APIKey" not in self.runtime.credentials - or not self.runtime.credentials.get("APIKey") - ): + if "APIKey" not in self.runtime.credentials or not self.runtime.credentials.get("APIKey"): return self.create_text_message("APIKey is required.") prompt = tool_parameters.get("prompt", "") diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py index 5bcc56a7248c1d..5959555318722e 100644 --- a/api/core/tools/provider/builtin/spider/spider.py +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -8,13 +8,13 @@ class SpiderProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - app = Spider(api_key=credentials['spider_api_key']) - app.scrape_url(url='https://spider.cloud') + app = Spider(api_key=credentials["spider_api_key"]) + app.scrape_url(url="https://spider.cloud") except AttributeError as e: # Handle cases where NoneType is not iterable, which might indicate API issues - if 'NoneType' in str(e) and 'not iterable' in str(e): - raise ToolProviderCredentialValidationError('API is currently down, try again in 15 minutes', str(e)) + if "NoneType" in str(e) and "not iterable" in str(e): + raise ToolProviderCredentialValidationError("API is currently down, try again in 15 minutes", str(e)) else: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) except Exception as e: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index f0ed64867a18a1..3972e560c41b3d 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -65,9 +65,7 @@ def api_post( :return: The JSON response or the raw response stream if stream is True. """ headers = self._prepare_headers(content_type) - response = self._post_request( - f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream - ) + response = self._post_request(f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream) if stream: return response @@ -76,9 +74,7 @@ def api_post( else: self._handle_error(response, f"post to {endpoint}") - def api_get( - self, endpoint: str, stream: bool, content_type: str = "application/json" - ): + def api_get(self, endpoint: str, stream: bool, content_type: str = "application/json"): """ Send a GET request to the specified endpoint. @@ -86,9 +82,7 @@ def api_get( :return: The JSON decoded response. """ headers = self._prepare_headers(content_type) - response = self._get_request( - f"https://api.spider.cloud/v1/{endpoint}", headers, stream - ) + response = self._get_request(f"https://api.spider.cloud/v1/{endpoint}", headers, stream) if response.status_code == 200: return response.json() else: @@ -120,14 +114,12 @@ def scrape_url( # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: - params["return_format"] = "markdown" + params["return_format"] = "markdown" # Set limit to 1 params["limit"] = 1 - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def crawl_url( self, @@ -150,9 +142,7 @@ def crawl_url( if "return_format" not in params: params["return_format"] = "markdown" - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def links( self, @@ -168,9 +158,7 @@ def links( :param params: Optional parameters for the link retrieval request. :return: JSON response containing the links. """ - return self.api_post( - "links", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("links", {"url": url, **(params or {})}, stream, content_type) def extract_contacts( self, @@ -207,9 +195,7 @@ def label( :param params: Optional parameters to guide the labeling process. :return: JSON response with labeled data. """ - return self.api_post( - "pipeline/label", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("pipeline/label", {"url": url, **(params or {})}, stream, content_type) def _prepare_headers(self, content_type: str = "application/json"): return { @@ -230,10 +216,6 @@ def _delete_request(self, url: str, headers, stream=False): def _handle_error(self, response, action): if response.status_code in [402, 409, 500]: error_message = response.json().get("error", "Unknown error occurred") - raise Exception( - f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}" - ) + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception( - f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}" - ) + raise Exception(f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}") diff --git a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py index 40736cd40218c6..20d2daef550de1 100644 --- a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py @@ -6,41 +6,43 @@ class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # initialize the app object with the api key - app = Spider(api_key=self.runtime.credentials['spider_api_key']) + app = Spider(api_key=self.runtime.credentials["spider_api_key"]) + + url = tool_parameters["url"] + mode = tool_parameters["mode"] - url = tool_parameters['url'] - mode = tool_parameters['mode'] - options = { - 'limit': tool_parameters.get('limit', 0), - 'depth': tool_parameters.get('depth', 0), - 'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [], - 'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [], - 'readability': tool_parameters.get('readability', False), + "limit": tool_parameters.get("limit", 0), + "depth": tool_parameters.get("depth", 0), + "blacklist": tool_parameters.get("blacklist", "").split(",") if tool_parameters.get("blacklist") else [], + "whitelist": tool_parameters.get("whitelist", "").split(",") if tool_parameters.get("whitelist") else [], + "readability": tool_parameters.get("readability", False), } result = "" try: - if mode == 'scrape': + if mode == "scrape": scrape_result = app.scrape_url( - url=url, + url=url, params=options, ) for i in scrape_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" - elif mode == 'crawl': + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" + elif mode == "crawl": crawl_result = app.crawl_url( - url=tool_parameters['url'], + url=tool_parameters["url"], params=options, ) for i in crawl_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" except Exception as e: return self.create_text_message("An error occurred", str(e)) diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index b31d786178dd63..f09d81ac270288 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This class is responsible for providing the stability tool. """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ This method is responsible for validating the credentials. diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py index a4788fd869ce1b..c3b7edbefa2447 100644 --- a/api/core/tools/provider/builtin/stability/tools/base.py +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -9,26 +9,23 @@ def sd_validate_credentials(self, credentials: dict): """ This method is responsible for validating the credentials. """ - api_key = credentials.get('api_key', '') + api_key = credentials.get("api_key", "") if not api_key: - raise ToolProviderCredentialValidationError('API key is required.') - + raise ToolProviderCredentialValidationError("API key is required.") + response = requests.get( - URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + URL("https://api.stability.ai") / "v1" / "user" / "account", headers=self.generate_authorization_headers(credentials), - timeout=(5, 30) + timeout=(5, 30), ) if not response.ok: - raise ToolProviderCredentialValidationError('Invalid API key.') + raise ToolProviderCredentialValidationError("Invalid API key.") return True - + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: """ This method is responsible for generating the authorization headers. """ - return { - 'Authorization': f'Bearer {credentials.get("api_key", "")}' - } - \ No newline at end of file + return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 41236f7b433cef..c33e3bd78fc913 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): """ This class is responsible for providing the stable diffusion tool. """ + model_endpoint_map: dict[str, str] = { - 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "core": "https://api.stability.ai/v2beta/stable-image/generate/core", } def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -22,39 +23,34 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Invoke the tool. """ payload = { - 'prompt': tool_parameters.get('prompt', ''), - 'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'), - 'mode': 'text-to-image', - 'seed': tool_parameters.get('seed', 0), - 'output_format': 'png', + "prompt": tool_parameters.get("prompt", ""), + "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"), + "mode": "text-to-image", + "seed": tool_parameters.get("seed", 0), + "output_format": "png", } - model = tool_parameters.get('model', 'core') + model = tool_parameters.get("model", "core") - if model in ['sd3', 'sd3-turbo']: - payload['model'] = tool_parameters.get('model') + if model in ["sd3", "sd3-turbo"]: + payload["model"] = tool_parameters.get("model") - if not model == 'sd3-turbo': - payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + if not model == "sd3-turbo": + payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( - self.model_endpoint_map[tool_parameters.get('model', 'core')], + self.model_endpoint_map[tool_parameters.get("model", "core")], headers={ - 'accept': 'image/*', + "accept": "image/*", **self.generate_authorization_headers(self.runtime.credentials), }, - files={ - key: (None, str(value)) for key, value in payload.items() - }, - timeout=(5, 30) + files={key: (None, str(value)) for key, value in payload.items()}, + timeout=(5, 30), ) if not response.status_code == 200: raise Exception(response.text) - + return self.create_blob_message( - blob=response.content, meta={ - 'mime_type': 'image/png' - }, - save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 317d705f7c2c7c..abaa297cf36eb1 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -15,4 +15,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 4be9207d66abcc..c31e1780674c99 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -18,19 +18,17 @@ # Prompts "prompt": "", "negative_prompt": "", - # "styles": [], - # Seeds + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, "seed_resize_from_w": -1, - # Samplers "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", - # Latent Space Options "batch_size": 1, "n_iter": 1, @@ -42,9 +40,9 @@ # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - # "eta": 0, - # "denoising_strength": 0.75, - # "s_min_uncond": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, # "s_churn": 0, # "s_tmax": 0, # "s_tmin": 0, @@ -73,7 +71,6 @@ "hr_negative_prompt": "", # Task Options # "force_task_id": "", - # Script Options # "script_name": "", "script_args": [], @@ -82,131 +79,130 @@ "save_images": False, "alwayson_scripts": {}, # "infotext": "", - } class StableDiffusionTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # base url - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - return self.create_text_message('Please input base_url') + return self.create_text_message("Please input base_url") - if tool_parameters.get('model'): - self.runtime.credentials['model'] = tool_parameters['model'] + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] - model = self.runtime.credentials.get('model', None) + model = self.runtime.credentials.get("model", None) if not model: - return self.create_text_message('Please input model') - + return self.create_text_message("Please input model") + # set model try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') - response = post(url, data=json.dumps({ - 'sd_model_checkpoint': model - })) + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post(url, data=json.dumps({"sd_model_checkpoint": model})) if response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") except Exception as e: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") # get image id and image variable - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") image_variable = self.get_default_image_variable() # Return text2img if there's no image ID or no image variable if not image_id or not image_variable: - return self.text2img(base_url=base_url,tool_parameters=tool_parameters) + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) # Proceed with image-to-image generation - return self.img2img(base_url=base_url,tool_parameters=tool_parameters) + return self.img2img(base_url=base_url, tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - validate models + validate models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - raise ToolProviderCredentialValidationError('Please input base_url') - model = self.runtime.credentials.get('model', None) + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) if not model: - raise ToolProviderCredentialValidationError('Please input model') + raise ToolProviderCredentialValidationError("Please input model") - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=10) if response.status_code == 404: # try draw a picture self._invoke( - user_id='test', + user_id="test", tool_parameters={ - 'prompt': 'a cat', - 'width': 1024, - 'height': 1024, - 'steps': 1, - 'lora': '', - } + "prompt": "a cat", + "width": 1024, + "height": 1024, + "steps": 1, + "lora": "", + }, ) elif response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to get models') + raise ToolProviderCredentialValidationError("Failed to get models") else: - models = [d['model_name'] for d in response.json()] + models = [d["model_name"] for d in response.json()] if len([d for d in models if d == model]) > 0: return self.create_text_message(json.dumps(models)) else: - raise ToolProviderCredentialValidationError(f'model {model} does not exist') + raise ToolProviderCredentialValidationError(f"model {model} does not exist") except Exception as e: - raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") def get_sd_models(self) -> list[str]: """ - get sd models + get sd models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['model_name'] for d in response.json()] + return [d["model_name"] for d in response.json()] except Exception as e: return [] - + def get_sample_methods(self) -> list[str]: """ - get sample method + get sample method """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['name'] for d in response.json()] + return [d["name"] for d in response.json()] except Exception as e: return [] - def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # Fetch the binary data of the image image_variable = self.get_default_image_variable() image_binary = self.get_variable_file(image_variable.name) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") # Convert image to RGB and save as PNG try: @@ -220,14 +216,14 @@ def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) # set image options - model = tool_parameters.get('model', '') + model = tool_parameters.get("model", "") draw_options_image = { - "init_images": [b64encode(image_binary).decode('utf-8')], + "init_images": [b64encode(image_binary).decode("utf-8")], "denoising_strength": 0.9, "restore_faces": False, "script_args": [], "override_settings": {"sd_model_checkpoint": model}, - "resize_mode":0, + "resize_mode": 0, "image_cfg_scale": 0, # "mask": None, "mask_blur_x": 4, @@ -247,136 +243,142 @@ def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt + draw_options["prompt"] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] - - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) + except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") - def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt - draw_options['override_settings']['sd_model_checkpoint'] = model + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model - try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] - - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) + except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") def get_runtime_parameters(self) -> list[ToolParameter]: parameters = [ - ToolParameter(name='prompt', - label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), - human_description=I18nObject( - en_US='Image prompt, you can check the official documentation of Stable Diffusion', - zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', - required=True), + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.", + required=True, + ), ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', - label=I18nObject(en_US='image_id', zh_Hans='image_id'), - human_description=I18nObject( - en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', - zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', - required=True, - options=[ToolParameterOption( - value=i.name, - label=I18nObject(en_US=i.name, zh_Hans=i.name) - ) for i in self.list_default_image_variables()]) + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) ) - + if self.runtime.credentials: try: models = self.get_sd_models() if len(models) != 0: parameters.append( - ToolParameter(name='model', - label=I18nObject(en_US='Model', zh_Hans='Model'), - human_description=I18nObject( - en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=models[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in models]) + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) ) - + except: pass - + sample_methods = self.get_sample_methods() if len(sample_methods) != 0: parameters.append( - ToolParameter(name='sampler_name', - label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), - human_description=I18nObject( - en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=sample_methods[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in sample_methods]) + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], ) + ) return parameters diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index de64c84997f7ca..9680c633cc701c 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -11,16 +11,15 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "intitle": "Test", - "sort": "relevance", + "sort": "relevance", "order": "desc", "site": "stackoverflow", "accepted": True, - "pagesize": 1 + "pagesize": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py index f8e17108444084..534532009501f5 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py @@ -17,7 +17,9 @@ class FetchAnsByStackExQuesIDInput(BaseModel): class FetchAnsByStackExQuesIDTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = FetchAnsByStackExQuesIDInput(**tool_parameters) params = { @@ -26,7 +28,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn "order": input.order, "sort": input.sort, "pagesize": input.pagesize, - "page": input.page + "page": input.page, } response = requests.get(f"https://api.stackexchange.com/2.3/questions/{input.id}/answers", params=params) @@ -34,4 +36,4 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py index 8436433c323cd1..4a25a808adf26a 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py @@ -9,26 +9,28 @@ class SearchStackExQuestionsInput(BaseModel): intitle: str = Field(..., description="The search query.") - sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") + sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") order: str = Field(..., description="asc or desc") site: str = Field(..., description="The Stack Exchange site.") tagged: str = Field(None, description="Semicolon-separated tags to include.") nottagged: str = Field(None, description="Semicolon-separated tags to exclude.") - accepted: bool = Field(..., description="true for only accepted answers, false otherwise") + accepted: bool = Field(..., description="true for only accepted answers, false otherwise") pagesize: int = Field(..., description="Number of results per page") class SearchStackExQuestionsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = SearchStackExQuestionsInput(**tool_parameters) params = { "intitle": input.intitle, "sort": input.sort, - "order": input.order, + "order": input.order, "site": input.site, "accepted": input.accepted, - "pagesize": input.pagesize + "pagesize": input.pagesize, } if input.tagged: params["tagged"] = input.tagged @@ -40,4 +42,4 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py index e809b04546aef5..b24f730c95a7c2 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.py +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -13,13 +13,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "cute girl, blue eyes, white hair, anime style", "size": "1024x1024", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index c571f546757b94..0b92b122bf59bc 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -9,61 +9,67 @@ class StepfunTool(BuiltinTool): - """ Stepfun Image Generation Tool """ - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Stepfun Image Generation Tool""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com') - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get("stepfun_base_url", "https://api.stepfun.com") + base_url = str(URL(base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['stepfun_api_key'], + api_key=self.runtime.credentials["stepfun_api_key"], base_url=base_url, ) extra_body = {} - model = tool_parameters.get('model', 'step-1x-medium') + model = tool_parameters.get("model", "step-1x-medium") if not model: - return self.create_text_message('Please input model name') + return self.create_text_message("Please input model name") # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") - seed = tool_parameters.get('seed', 0) + seed = tool_parameters.get("seed", 0) if seed > 0: - extra_body['seed'] = seed - steps = tool_parameters.get('steps', 0) + extra_body["seed"] = seed + steps = tool_parameters.get("steps", 0) if steps > 0: - extra_body['steps'] = steps - negative_prompt = tool_parameters.get('negative_prompt', '') + extra_body["steps"] = steps + negative_prompt = tool_parameters.get("negative_prompt", "") if negative_prompt: - extra_body['negative_prompt'] = negative_prompt + extra_body["negative_prompt"] = negative_prompt # call openapi stepfun model response = client.images.generate( prompt=prompt, model=model, - size=tool_parameters.get('size', '1024x1024'), - n=tool_parameters.get('n', 1), - extra_body= extra_body + size=tool_parameters.get("size", "1024x1024"), + n=tool_parameters.get("n", 1), + extra_body=extra_body, ) print(response) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index e376d99d6bb951..a702b0a74e6131 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -13,7 +13,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", "search_depth": "basic", @@ -22,9 +22,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "include_raw_content": False, "max_results": 5, "include_domains": "", - "exclude_domains": "" + "exclude_domains": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 0200df3c8a4c31..ca6d8633e4b0af 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -36,15 +36,23 @@ def raw_results(self, params: dict[str, Any]) -> dict: """ params["api_key"] = self.api_key - if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': - params['exclude_domains'] = params['exclude_domains'].split() + if ( + "exclude_domains" in params + and isinstance(params["exclude_domains"], str) + and params["exclude_domains"] != "None" + ): + params["exclude_domains"] = params["exclude_domains"].split() else: - params['exclude_domains'] = [] - if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': - params['include_domains'] = params['include_domains'].split() + params["exclude_domains"] = [] + if ( + "include_domains" in params + and isinstance(params["include_domains"], str) + and params["include_domains"] != "None" + ): + params["include_domains"] = params["include_domains"].split() else: - params['include_domains'] = [] - + params["include_domains"] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() @@ -91,9 +99,7 @@ class TavilySearchTool(BuiltinTool): A tool for searching Tavily using a given query. """ - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Tavily search tool with the given user ID and tool parameters. @@ -115,4 +121,4 @@ def _invoke( if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) \ No newline at end of file + return self.create_text_message(text=results) diff --git a/api/core/tools/provider/builtin/tianditu/tianditu.py b/api/core/tools/provider/builtin/tianditu/tianditu.py index 1f96be06b0200d..cb7d7bd8bb2c41 100644 --- a/api/core/tools/provider/builtin/tianditu/tianditu.py +++ b/api/core/tools/provider/builtin/tianditu/tianditu.py @@ -12,10 +12,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: runtime={ "credentials": credentials, } - ).invoke(user_id='', - tool_parameters={ - 'content': '北京', - 'specify': '156110000', - }) + ).invoke( + user_id="", + tool_parameters={ + "content": "北京", + "specify": "156110000", + }, + ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py index 484a3768c851df..690a0aed6f5aff 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py +++ b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py @@ -8,26 +8,26 @@ class GeocoderTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = 'http://api.tianditu.gov.cn/geocoder' - - keyword = tool_parameters.get('keyword', '') + base_url = "http://api.tianditu.gov.cn/geocoder" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + params = { - 'keyWord': keyword, + "keyWord": keyword, } - - result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json() + + result = requests.get(base_url + "?ds=" + json.dumps(params, ensure_ascii=False) + "&tk=" + tk).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py index 08a5b8ef42a8c8..798dd94d335654 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py +++ b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py @@ -8,38 +8,51 @@ class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/v2/search' - - keyword = tool_parameters.get('keyword', '') + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/v2/search" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - baseAddress = tool_parameters.get('baseAddress', '') + return self.create_text_message("Invalid parameter keyword") + + baseAddress = tool_parameters.get("baseAddress", "") if not baseAddress: - return self.create_text_message('Invalid parameter baseAddress') - - tk = self.runtime.credentials['tianditu_api_key'] - - base_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress,}, ensure_ascii=False) + '&tk=' + tk).json() - + return self.create_text_message("Invalid parameter baseAddress") + + tk = self.runtime.credentials["tianditu_api_key"] + + base_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": baseAddress, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + params = { - 'keyWord': keyword, - 'queryRadius': 5000, - 'queryType': 3, - 'pointLonlat': base_coords['location']['lon'] + ',' + base_coords['location']['lat'], - 'start': 0, - 'count': 100, + "keyWord": keyword, + "queryRadius": 5000, + "queryType": 3, + "pointLonlat": base_coords["location"]["lon"] + "," + base_coords["location"]["lat"], + "start": 0, + "count": 100, } - - result = requests.get(base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk).json() + + result = requests.get( + base_url + "?postStr=" + json.dumps(params, ensure_ascii=False) + "&type=query&tk=" + tk + ).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index ecac4404ca28b0..93803d7937042d 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -8,29 +8,42 @@ class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/staticimage' - - keyword = tool_parameters.get('keyword', '') + + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/staticimage" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - - keyword_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword,}, ensure_ascii=False) + '&tk=' + tk).json() - coords = keyword_coords['location']['lon'] + ',' + keyword_coords['location']['lat'] - - result = requests.get(base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk).content - - return self.create_blob_message(blob=result, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + + keyword_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": keyword, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + coords = keyword_coords["location"]["lon"] + "," + keyword_coords["location"]["lat"] + + result = requests.get( + base_url + "?center=" + coords + "&markers=" + coords + "&width=400&height=300&zoom=14&tk=" + tk + ).content + + return self.create_blob_message( + blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 833ae194ef840c..e4df8d616cba38 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -9,9 +9,8 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CurrentTimeTool().invoke( - user_id='', + user_id="", tool_parameters={}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 90c01665e6e6a1..cc38739c16f04b 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,21 +8,22 @@ class CurrentTimeTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get timezone - tz = tool_parameters.get('timezone', 'UTC') - fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' - if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') - + tz = tool_parameters.get("timezone", "UTC") + fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" + if tz == "UTC": + return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + try: tz = pytz_timezone(tz) except: - return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file + return self.create_text_message(f"Invalid timezone: {tz}") + return self.create_text_message(f"{datetime.now(tz).strftime(fm)}") diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py b/api/core/tools/provider/builtin/time/tools/weekday.py index 4461cb5a32d14e..b327e54e171048 100644 --- a/api/core/tools/provider/builtin/time/tools/weekday.py +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -7,25 +7,26 @@ class WeekdayTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Calculate the day of the week for a given date + Calculate the day of the week for a given date """ - year = tool_parameters.get('year') - month = tool_parameters.get('month') - day = tool_parameters.get('day') + year = tool_parameters.get("year") + month = tool_parameters.get("month") + day = tool_parameters.get("day") date_obj = self.convert_datetime(year, month, day) if not date_obj: - return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" - return self.create_text_message(f'{readable_date} is {weekday_name}.') + return self.create_text_message(f"{readable_date} is {weekday_name}.") @staticmethod def convert_datetime(year, month, day) -> datetime | None: diff --git a/api/core/tools/provider/builtin/trello/tools/create_board.py b/api/core/tools/provider/builtin/trello/tools/create_board.py index 2655602afa82d3..5a61d221578995 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_board.py @@ -22,19 +22,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_name = tool_parameters.get("name") if not (api_key and token and board_name): return self.create_text_message("Missing required parameters: API key, token, or board name.") url = "https://api.trello.com/1/boards/" - query_params = { - 'name': board_name, - 'key': api_key, - 'token': token - } + query_params = {"name": board_name, "key": api_key, "token": token} try: response = requests.post(url, params=query_params) @@ -43,5 +39,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to create board") board = response.json() - return self.create_text_message(text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}") - + return self.create_text_message( + text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py index f5b156cb44c2ee..26f12864c3942b 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py @@ -22,20 +22,16 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('id') - list_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("id") + list_name = tool_parameters.get("name") if not (api_key and token and board_id and list_name): return self.create_text_message("Missing required parameters: API key, token, board ID, or list name.") url = f"https://api.trello.com/1/boards/{board_id}/lists" - params = { - 'name': list_name, - 'key': api_key, - 'token': token - } + params = {"name": list_name, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -44,5 +40,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to create list") new_list = response.json() - return self.create_text_message(text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}.") - + return self.create_text_message( + text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py index 74b73b40e54f5d..dfc013a6b8bb73 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py @@ -22,15 +22,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") # Ensure required parameters are present - if 'name' not in tool_parameters or 'idList' not in tool_parameters: + if "name" not in tool_parameters or "idList" not in tool_parameters: return self.create_text_message("Missing required parameters: name or idList.") url = "https://api.trello.com/1/cards" - params = {**tool_parameters, 'key': api_key, 'token': token} + params = {**tool_parameters, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -39,5 +39,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, except requests.exceptions.RequestException as e: return self.create_text_message("Failed to create card") - return self.create_text_message(text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}.") - + return self.create_text_message( + text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/delete_board.py b/api/core/tools/provider/builtin/trello/tools/delete_board.py index 29df3fda2d23ec..9dbd8f78d5e65f 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_board.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_board.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,4 +38,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to delete board") return self.create_text_message(text=f"Board with ID {board_id} deleted successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/delete_card.py b/api/core/tools/provider/builtin/trello/tools/delete_card.py index 2ced5f6c14f9f7..960c3055fe94a4 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_card.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_card.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") @@ -38,4 +38,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to delete card") return self.create_text_message(text=f"Card with ID {card_id} has been successfully deleted.") - diff --git a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py index f9d554c6fb0478..0c5ed9ea8533ff 100644 --- a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py +++ b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py @@ -28,9 +28,7 @@ def _invoke( token = self.runtime.credentials.get("trello_api_token") if not (api_key and token): - return self.create_text_message( - "Missing Trello API key or token in credentials." - ) + return self.create_text_message("Missing Trello API key or token in credentials.") # Including board filter in the request if provided board_filter = tool_parameters.get("boards", "open") @@ -48,7 +46,5 @@ def _invoke( return self.create_text_message("No boards found in Trello.") # Creating a string with both board names and IDs - boards_info = ", ".join( - [f"{board['name']} (ID: {board['id']})" for board in boards] - ) + boards_info = ", ".join([f"{board['name']} (ID: {board['id']})" for board in boards]) return self.create_text_message(text=f"Boards: {boards_info}") diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py index 5678d8f8d76d7c..03510f196488b1 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,6 +38,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] except requests.exceptions.RequestException as e: return self.create_text_message("Failed to retrieve board actions") - actions_summary = "\n".join([f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions]) + actions_summary = "\n".join( + [f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions] + ) return self.create_text_message(text=f"Actions for Board ID {board_id}:\n{actions_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py index ee6cb065e5a9fa..5b41b128d07ca8 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -63,4 +63,3 @@ def format_board_details(self, board: dict) -> str: f"Background Color: {board['prefs']['backgroundColor']}" ) return details - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py index 1abb688750af53..e3bed2e6e620bc 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] cards_summary = "\n".join([f"{card['name']} (ID: {card['id']})" for card in cards]) return self.create_text_message(text=f"Cards for Board ID {board_id}:\n{cards_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py index 375ead5b1d232a..4d8854747c6963 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py @@ -22,10 +22,10 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') - filter = tool_parameters.get('filter') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") + filter = tool_parameters.get("filter") if not (api_key and token and board_id and filter): return self.create_text_message("Missing required parameters: API key, token, board ID, or filter.") @@ -40,5 +40,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to retrieve filtered cards") card_details = "\n".join([f"{card['name']} (ID: {card['id']})" for card in filtered_cards]) - return self.create_text_message(text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}") - + return self.create_text_message( + text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py index 7b9b9cf24b7543..ca8aa9c2d5daf4 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] lists_info = "\n".join([f"{list['name']} (ID: {list['id']})" for list in lists]) return self.create_text_message(text=f"Lists on Board ID {board_id}:\n{lists_info}") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_board.py b/api/core/tools/provider/builtin/trello/tools/update_board.py index 7ad6ac2e64ef46..62681eea6b4a9f 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_board.py +++ b/api/core/tools/provider/builtin/trello/tools/update_board.py @@ -22,9 +22,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.pop('boardId', None) + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.pop("boardId", None) if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -33,8 +33,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, # Removing parameters not intended for update action or with None value params = {k: v for k, v in tool_parameters.items() if v is not None} - params['key'] = api_key - params['token'] = token + params["key"] = api_key + params["token"] = token try: response = requests.put(url, params=params) @@ -44,4 +44,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, updated_board = response.json() return self.create_text_message(text=f"Board '{updated_board['name']}' updated successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_card.py b/api/core/tools/provider/builtin/trello/tools/update_card.py index 417344350cbc18..26113f12290613 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_card.py +++ b/api/core/tools/provider/builtin/trello/tools/update_card.py @@ -22,17 +22,17 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") # Constructing the URL and the payload for the PUT request url = f"https://api.trello.com/1/cards/{card_id}" - params = {k: v for k, v in tool_parameters.items() if v is not None and k != 'id'} - params.update({'key': api_key, 'token': token}) + params = {k: v for k, v in tool_parameters.items() if v is not None and k != "id"} + params.update({"key": api_key, "token": token}) try: response = requests.put(url, params=params) diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index 84ecd208037037..e0dca50ec99aee 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -9,17 +9,17 @@ class TrelloProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: """Validate Trello API credentials by making a test API call. - + Args: credentials (dict[str, Any]): The Trello API credentials to validate. - + Raises: ToolProviderCredentialValidationError: If the credentials are invalid. """ api_key = credentials.get("trello_api_key") token = credentials.get("trello_api_token") url = f"https://api.trello.com/1/members/me?key={api_key}&token={token}" - + try: response = requests.get(url) response.raise_for_status() # Raises an HTTPError for bad responses @@ -32,4 +32,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors raise ToolProviderCredentialValidationError("Error validating Trello credentials") - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 1c52589956c708..822d0c0ebdd3a8 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -32,17 +32,14 @@ class TwilioAPIWrapper(BaseModel): must be empty. """ - @field_validator('client', mode='before') + @field_validator("client", mode="before") @classmethod def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: from twilio.rest import Client except ImportError: - raise ImportError( - "Could not import twilio python package. " - "Please install it with `pip install twilio`." - ) + raise ImportError("Could not import twilio python package. " "Please install it with `pip install twilio`.") account_sid = values.get("account_sid") auth_token = values.get("auth_token") values["from_number"] = values.get("from_number") @@ -91,9 +88,7 @@ def _invoke( if to_number.startswith("whatsapp:"): from_number = f"whatsapp: {from_number}" - twilio = TwilioAPIWrapper( - account_sid=account_sid, auth_token=auth_token, from_number=from_number - ) + twilio = TwilioAPIWrapper(account_sid=account_sid, auth_token=auth_token, from_number=from_number) # Sending the message through Twilio result = twilio.run(message, to_number) diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 06f276053a9c63..b1d100aad93dba 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -14,7 +14,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: account_sid = credentials["account_sid"] auth_token = credentials["auth_token"] from_number = credentials["from_number"] - + # Initialize twilio client client = Client(account_sid, auth_token) @@ -27,4 +27,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index ab1fd71df5e191..84724e921a1b00 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -13,13 +13,13 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "model": "chinook", "db_type": "SQLite", "url": "https://vanna.ai/Chinook.sqlite", - "query": "What are the top 10 customers by sales?" + "query": "What are the top 10 customers by sales?", }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py index 1506ac0c9ded93..8e1b0977766706 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -1 +1 @@ -VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file +VECTORIZER_ICON_PNG = "iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC" diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index c6ec1980342d75..3ba4996be1ff8d 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -10,65 +10,60 @@ class VectorizerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - api_key_name = self.runtime.credentials.get('api_key_name', None) - api_key_value = self.runtime.credentials.get('api_key_value', None) - mode = tool_parameters.get('mode', 'test') - if mode == 'production': - mode = 'preview' + api_key_name = self.runtime.credentials.get("api_key_name", None) + api_key_value = self.runtime.credentials.get("api_key_value", None) + mode = tool_parameters.get("mode", "test") + if mode == "production": + mode = "preview" if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError('Please input api key name and value') + raise ToolProviderCredentialValidationError("Please input api key name and value") - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") if not image_id: - return self.create_text_message('Please input image id') - - if image_id.startswith('__test_'): + return self.create_text_message("Please input image id") + + if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") response = post( - 'https://vectorizer.ai/api/v1/vectorize', - files={ - 'image': image_binary - }, - data={ - 'mode': mode - } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), - timeout=30 + "https://vectorizer.ai/api/v1/vectorize", + files={"image": image_binary}, + data={"mode": mode} if mode == "test" else {}, + auth=(api_key_name, api_key_value), + timeout=30, ) if response.status_code != 200: raise Exception(response.text) - + return [ - self.create_text_message('the vectorized svg is saved as an image.'), - self.create_blob_message(blob=response.content, - meta={'mime_type': 'image/svg+xml'}) + self.create_text_message("the vectorized svg is saved as an image."), + self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), ] - + def get_runtime_parameters(self) -> list[ToolParameter]: """ override the runtime parameters """ return [ ToolParameter.get_simple_instance( - name='image_id', - llm_description=f'the image id that you want to vectorize, \ + name="image_id", + llm_description=f"the image id that you want to vectorize, \ and the image id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}', + {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, required=True, - options=[i.name for i in self.list_default_image_variables()] + options=[i.name for i in self.list_default_image_variables()], ) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3f89a83500da9f..3b868572f93bae 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "mode": "test", - "image_id": "__test_123" - }, + user_id="", + tool_parameters={"mode": "test", "image_id": "__test_123"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 3d098e6768a8e7..12670b4b8b9289 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -6,23 +6,24 @@ class WebscraperTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: - url = tool_parameters.get('url', '') - user_agent = tool_parameters.get('user_agent', '') + url = tool_parameters.get("url", "") + user_agent = tool_parameters.get("user_agent", "") if not url: - return self.create_text_message('Please input url') + return self.create_text_message("Please input url") # get webpage result = self.get_url(url, user_agent=user_agent) - if tool_parameters.get('generate_summary'): + if tool_parameters.get("generate_summary"): # summarize and return return self.create_text_message(self.summary(user_id=user_id, content=result)) else: diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 1e60fdb2939d93..3c51393ac64cc4 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -13,12 +13,11 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - 'url': 'https://www.google.com', - 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + "url": "https://www.google.com", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py index 91283059229645..293f4f63297120 100644 --- a/api/core/tools/provider/builtin/websearch/tools/job_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py @@ -50,14 +50,16 @@ def parse_results(res: dict) -> str: for job in jobs[:10]: try: string.append( - "\n".join([ - f"Position: {job['position']}", - f"Employer: {job['employer']}", - f"Location: {job['location']}", - f"Link: {job['link']}", - f"""Highest: {", ".join(list(job["highlights"]))}""", - "---", - ]) + "\n".join( + [ + f"Position: {job['position']}", + f"Employer: {job['employer']}", + f"Location: {job['location']}", + f"Link: {job['link']}", + f"""Highest: {", ".join(list(job["highlights"]))}""", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py index e9c0744f054aa3..9b5482fe183e18 100644 --- a/api/core/tools/provider/builtin/websearch/tools/news_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py @@ -53,13 +53,15 @@ def parse_results(res: dict) -> str: r = requests.get(entry["link"]) final_link = r.history[-1].headers["Location"] string.append( - "\n".join([ - f"Title: {entry['title']}", - f"Link: {final_link}", - f"Source: {entry['source']['title']}", - f"Published: {entry['published']}", - "---", - ]) + "\n".join( + [ + f"Title: {entry['title']}", + f"Link: {final_link}", + f"Source: {entry['source']['title']}", + f"Published: {entry['published']}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py index 0030a03c06a5d8..798d059b512edf 100644 --- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py @@ -55,14 +55,16 @@ def parse_results(res: dict) -> str: link = article["link"] authors = [author["name"] for author in article["author"]["authors"]] string.append( - "\n".join([ - f"Title: {article['title']}", - f"Link: {link}", - f"Description: {article['description']}", - f"Cite: {article['cite']}", - f"Authors: {', '.join(authors)}", - "---", - ]) + "\n".join( + [ + f"Title: {article['title']}", + f"Link: {link}", + f"Description: {article['description']}", + f"Cite: {article['cite']}", + f"Authors: {', '.join(authors)}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/web_search.py b/api/core/tools/provider/builtin/websearch/tools/web_search.py index 4f57c27caf5257..fe363ac7a4d5d0 100644 --- a/api/core/tools/provider/builtin/websearch/tools/web_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/web_search.py @@ -49,12 +49,14 @@ def parse_results(res: dict) -> str: for result in results: try: string.append( - "\n".join([ - f"Title: {result['title']}", - f"Link: {result['link']}", - f"Description: {result['description'].strip()}", - "---", - ]) + "\n".join( + [ + f"Title: {result['title']}", + f"Link: {result['link']}", + f"Description: {result['description'].strip()}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py index fb44b70f4ec6b6..545d9f4f8d6335 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py @@ -8,41 +8,41 @@ class WecomGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - message_type = tool_parameters.get('message_type', 'text') - if message_type == 'markdown': + message_type = tool_parameters.get("message_type", "text") + if message_type == "markdown": payload = { - "msgtype": 'markdown', + "msgtype": "markdown", "markdown": { "content": content, - } + }, } else: payload = { - "msgtype": 'text', + "msgtype": "text", "text": { "content": content, - } + }, } - api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' + api_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'key': hook_key, + "key": hook_key, } try: @@ -51,6 +51,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 0796cd2392a68b..67efcf0954739e 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -83,7 +83,6 @@ def _run( class WikiPediaSearchTool(BuiltinTool): - def _invoke( self, user_id: str, diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index f8038714a5f524..178bf7b0ceb2e9 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "misaka mikoto", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 8cb9c10ddf499a..9dc5bed824d715 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -8,29 +8,24 @@ class WolframAlphaTool(BuiltinTool): - _base_url = 'https://api.wolframalpha.com/v2/query' + _base_url = "https://api.wolframalpha.com/v2/query" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - appid = self.runtime.credentials.get('appid', '') + return self.create_text_message("Please input query") + appid = self.runtime.credentials.get("appid", "") if not appid: - raise ToolProviderCredentialValidationError('Please input appid') - - params = { - 'appid': appid, - 'input': query, - 'includepodid': 'Result', - 'format': 'plaintext', - 'output': 'json' - } + raise ToolProviderCredentialValidationError("Please input appid") + + params = {"appid": appid, "input": query, "includepodid": "Result", "format": "plaintext", "output": "json"} finished = False result = None @@ -45,34 +40,33 @@ def _invoke(self, response_data = response.json() except Exception as e: raise ToolInvokeError(str(e)) - - if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: - query_result = response_data.get('queryresult', {}) - if query_result.get('error'): - if 'msg' in query_result['error']: - if query_result['error']['msg'] == 'Invalid appid': - raise ToolProviderCredentialValidationError('Invalid appid') - raise ToolInvokeError('Failed to invoke tool') - - if 'didyoumeans' in response_data['queryresult']: + + if "success" not in response_data["queryresult"] or response_data["queryresult"]["success"] != True: + query_result = response_data.get("queryresult", {}) + if query_result.get("error"): + if "msg" in query_result["error"]: + if query_result["error"]["msg"] == "Invalid appid": + raise ToolProviderCredentialValidationError("Invalid appid") + raise ToolInvokeError("Failed to invoke tool") + + if "didyoumeans" in response_data["queryresult"]: # get the most likely interpretation - query = '' + query = "" max_score = 0 - for didyoumean in response_data['queryresult']['didyoumeans']: - if float(didyoumean['score']) > max_score: - query = didyoumean['val'] - max_score = float(didyoumean['score']) + for didyoumean in response_data["queryresult"]["didyoumeans"]: + if float(didyoumean["score"]) > max_score: + query = didyoumean["val"] + max_score = float(didyoumean["score"]) - params['input'] = query + params["input"] = query else: finished = True - if 'souces' in response_data['queryresult']: - return self.create_link_message(response_data['queryresult']['sources']['url']) - elif 'pods' in response_data['queryresult']: - result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext'] + if "souces" in response_data["queryresult"]: + return self.create_link_message(response_data["queryresult"]["sources"]["url"]) + elif "pods" in response_data["queryresult"]: + result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"] if not finished or not result: - return self.create_text_message('No result found') + return self.create_text_message("No result found") return self.create_text_message(result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index ef1aac7ff272c5..7be288b5387f34 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -13,11 +13,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "1+2+....+111", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index cf511ea8940082..f044fbe5404b0a 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -10,27 +10,28 @@ class YahooFinanceAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - symbol = tool_parameters.get('symbol', '') + symbol = tool_parameters.get("symbol", "") if not symbol: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") stock_data = download(symbol, start=time_range[0], end=time_range[1]) max_segments = min(15, len(stock_data)) @@ -41,30 +42,29 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data) segment_data = stock_data.iloc[start_idx:end_idx] segment_summary = { - 'Start Date': segment_data.index[0], - 'End Date': segment_data.index[-1], - 'Average Close': segment_data['Close'].mean(), - 'Average Volume': segment_data['Volume'].mean(), - 'Average Open': segment_data['Open'].mean(), - 'Average High': segment_data['High'].mean(), - 'Average Low': segment_data['Low'].mean(), - 'Average Adj Close': segment_data['Adj Close'].mean(), - 'Max Close': segment_data['Close'].max(), - 'Min Close': segment_data['Close'].min(), - 'Max Volume': segment_data['Volume'].max(), - 'Min Volume': segment_data['Volume'].min(), - 'Max Open': segment_data['Open'].max(), - 'Min Open': segment_data['Open'].min(), - 'Max High': segment_data['High'].max(), - 'Min High': segment_data['High'].min(), + "Start Date": segment_data.index[0], + "End Date": segment_data.index[-1], + "Average Close": segment_data["Close"].mean(), + "Average Volume": segment_data["Volume"].mean(), + "Average Open": segment_data["Open"].mean(), + "Average High": segment_data["High"].mean(), + "Average Low": segment_data["Low"].mean(), + "Average Adj Close": segment_data["Adj Close"].mean(), + "Max Close": segment_data["Close"].max(), + "Min Close": segment_data["Close"].min(), + "Max Volume": segment_data["Volume"].max(), + "Min Volume": segment_data["Volume"].min(), + "Max Open": segment_data["Open"].max(), + "Min Open": segment_data["Open"].min(), + "Max High": segment_data["High"].max(), + "Min High": segment_data["High"].min(), } - + summary_data.append(segment_summary) summary_df = pd.DataFrame(summary_data) - + try: return self.create_text_message(str(summary_df.to_dict())) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - \ No newline at end of file + return self.create_text_message("There is a internet connection problem. Please try again later.") diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 4f2922ef3ec1de..ff820430f9f366 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -8,40 +8,39 @@ class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - ''' - invoke tools - ''' - - query = tool_parameters.get('symbol', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.run(ticker=query, user_id=user_id) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') + return self.create_text_message("There is a internet connection problem. Please try again later.") def run(self, ticker: str, user_id: str) -> ToolInvokeMessage: company = yfinance.Ticker(ticker) try: if company.isin is None: - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") except (HTTPError, ReadTimeout, ConnectionError): - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") links = [] try: - links = [n['link'] for n in company.news if n['type'] == 'STORY'] + links = [n["link"] for n in company.news if n["type"] == "STORY"] except (HTTPError, ReadTimeout, ConnectionError): if not links: - return self.create_text_message(f'There is nothing about {ticker} ticker') + return self.create_text_message(f"There is nothing about {ticker} ticker") if not links: - return self.create_text_message(f'No news found for company that searched with {ticker} ticker.') - - result = '\n\n'.join([ - self.get_url(link) for link in links - ]) + return self.create_text_message(f"No news found for company that searched with {ticker} ticker.") + + result = "\n\n".join([self.get_url(link) for link in links]) return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index 262fff3b25ba93..dfc7e460473c33 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -8,19 +8,20 @@ class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('symbol', '') + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.create_text_message(self.run(ticker=query)) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - + return self.create_text_message("There is a internet connection problem. Please try again later.") + def run(self, ticker: str) -> str: - return str(Ticker(ticker).info) \ No newline at end of file + return str(Ticker(ticker).info) diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index 96dbc6c3d0d8e9..8d82084e769703 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "ticker": "MSFT", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 7a9b9fce4a921f..95dec2eac9a752 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -8,60 +8,67 @@ class YoutubeVideosAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - channel = tool_parameters.get('channel', '') + channel = tool_parameters.get("channel", "") if not channel: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") - if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']: - return self.create_text_message('Please input api key') + if "google_api_key" not in self.runtime.credentials or not self.runtime.credentials["google_api_key"]: + return self.create_text_message("Please input api key") - youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) + youtube = build("youtube", "v3", developerKey=self.runtime.credentials["google_api_key"]) # try to get channel id - search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute() - channel_id = search_results['items'][0]['id']['channelId'] + search_results = youtube.search().list(q=channel, type="channel", order="relevance", part="id").execute() + channel_id = search_results["items"][0]["id"]["channelId"] start_date, end_date = time_range - start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') - end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") + end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") # get videos - time_range_videos = youtube.search().list( - part='snippet', channelId=channel_id, order='date', type='video', - publishedAfter=start_date, - publishedBefore=end_date - ).execute() + time_range_videos = ( + youtube.search() + .list( + part="snippet", + channelId=channel_id, + order="date", + type="video", + publishedAfter=start_date, + publishedBefore=end_date, + ) + .execute() + ) def extract_video_data(video_list): data = [] - for video in video_list['items']: - video_id = video['id']['videoId'] - video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() - title = video_info['items'][0]['snippet']['title'] - views = video_info['items'][0]['statistics']['viewCount'] - data.append({'Title': title, 'Views': views}) + for video in video_list["items"]: + video_id = video["id"]["videoId"] + video_info = youtube.videos().list(part="snippet,statistics", id=video_id).execute() + title = video_info["items"][0]["snippet"]["title"] + views = video_info["items"][0]["statistics"]["viewCount"] + data.append({"Title": title, "Views": views}) return data summary = extract_video_data(time_range_videos) - + return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 83a4fccb3247d0..aad876491c85dc 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -11,7 +11,7 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "channel": "TOKYO GIRLS COLLECTION", "start_date": "2020-01-01", @@ -20,4 +20,3 @@ def _validate_credentials(self, credentials: dict) -> None: ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index bcf41c90edbfcd..6b64dd1b4e3201 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -22,34 +22,36 @@ def __init__(self, **data: Any) -> None: if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) return - + # load provider yaml - provider = self.__class__.__module__.split('.')[-1] - yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') + provider = self.__class__.__module__.split(".")[-1] + yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml") try: provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') + raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: + if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name - for credential_name in provider_yaml['credentials_for_provider']: - provider_yaml['credentials_for_provider'][credential_name]['name'] = credential_name + for credential_name in provider_yaml["credentials_for_provider"]: + provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - super().__init__(**{ - 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', None), - }) + super().__init__( + **{ + "identity": provider_yaml["identity"], + "credentials_schema": provider_yaml.get("credentials_for_provider", None), + } + ) def _get_builtin_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ if self.tools: return self.tools - + provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") # get all the yaml files in the tool path @@ -62,155 +64,161 @@ def _get_builtin_tools(self) -> list[Tool]: # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) + module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py" + ), + parent_type=BuiltinTool, + ) tool["identity"]["provider"] = provider tools.append(assistant_tool_class(**tool)) self.tools = tools return tools - + def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ if not self.credentials_schema: return {} - + return self.credentials_schema.copy() def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ return self._get_builtin_tools() - + def get_tool(self, tool_name: str) -> Tool: """ - returns the tool that the provider can provide + returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def need_credentials(self) -> bool: """ - returns whether the provider needs credentials + returns whether the provider needs credentials - :return: whether the provider needs credentials + :return: whether the provider needs credentials """ - return self.credentials_schema is not None and \ - len(self.credentials_schema) != 0 + return self.credentials_schema is not None and len(self.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN @property def tool_labels(self) -> list[str]: """ - returns the labels of the provider + returns the labels of the provider - :return: labels of the provider + :return: labels of the provider """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label in label_enums] def _get_tool_labels(self) -> list[ToolLabelEnum]: """ - returns the labels of the provider + returns the labels of the provider """ return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + default_value = ToolParameterConverter.cast_parameter_by_type( + parameter_schema.default, parameter_schema.type + ) tool_parameters[parameter] = default_value - + def validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) @@ -221,9 +229,9 @@ def validate_credentials(self, credentials: dict[str, Any]) -> None: @abstractmethod def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ef1ace9c7c31e7..f4008eedce8154 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -21,162 +21,174 @@ class ToolProviderController(BaseModel, ABC): def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ return self.credentials_schema.copy() - + @abstractmethod def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ pass @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ - returns a tool that the provider can provide + returns a tool that the provider can provide - :return: tool + :return: tool """ pass def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type( + parameter_schema.default, parameter_schema.type + ) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ - validate the format of the credentials of the provider and set the default value if needed + validate the format of the credentials of the provider and set the default value if needed - :param credentials: the credentials of the tool + :param credentials: the credentials of the tool """ credentials_schema = self.credentials_schema if credentials_schema is None: return - + credentials_need_to_validate: dict[str, ToolProviderCredentials] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] for credential_name in credentials: if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.identity.name}" + ) + # check type credential_schema = credentials_need_to_validate[credential_name] - if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: + if ( + credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT + or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT + ): if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + options = credential_schema.options if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + credentials_need_to_validate.pop(credential_name) for credential_name in credentials_need_to_validate: credential_schema = credentials_need_to_validate[credential_name] if credential_schema.required: - raise ToolProviderCredentialValidationError(f'credential {credential_name} is required') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + # the credential is not set currently, set the default value if needed if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: + if ( + credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT + or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT + or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT + ): default_value = str(default_value) credentials[credential_name] = default_value - \ No newline at end of file diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f14abac76777da..25eaf6a66a6658 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -30,29 +30,25 @@ class WorkflowToolProviderController(ToolProviderController): provider_id: str @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': + def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": app = db_provider.app if not app: - raise ValueError('app not found') - - controller = WorkflowToolProviderController(**{ - 'identity': { - 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', - 'name': db_provider.label, - 'label': { - 'en_US': db_provider.label, - 'zh_Hans': db_provider.label + raise ValueError("app not found") + + controller = WorkflowToolProviderController( + **{ + "identity": { + "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", + "name": db_provider.label, + "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, - }, - 'credentials_schema': {}, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": {}, + "provider_id": db_provider.id or "", + } + ) # init tools @@ -66,25 +62,23 @@ def provider_type(self) -> ToolProviderType: def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ - get db provider tool - :param db_provider: the db provider - :param app: the app - :return: the tool + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool """ - workflow: Workflow = db.session.query(Workflow).filter( - Workflow.app_id == db_provider.app_id, - Workflow.version == db_provider.version - ).first() + workflow: Workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) + .first() + ) if not workflow: - raise ValueError('workflow not found') + raise ValueError("workflow not found") # fetch start node graph: dict = workflow.graph_dict features_dict: dict = workflow.features_dict - features = WorkflowAppConfigManager.convert_features( - config_dict=features_dict, - app_mode=AppMode.WORKFLOW - ) + features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) @@ -101,51 +95,34 @@ def fetch_workflow_variable(variable_name: str) -> VariableEntity: parameter_type = None options = None if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: - raise ValueError(f'unsupported variable type {variable.type}') + raise ValueError(f"unsupported variable type {variable.type}") parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] if variable.type == VariableEntityType.SELECT and variable.options: options = [ - ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in variable.options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in variable.options ] workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=variable.label, - zh_Hans=variable.label - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=variable.label, zh_Hans=variable.label), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=parameter_type, form=parameter.form, llm_description=parameter.description, required=variable.required, options=options, - default=variable.default + default=variable.default, ) ) elif features.file_upload: workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=parameter.name, - zh_Hans=parameter.name - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=ToolParameter.ToolParameterType.FILE, llm_description=parameter.description, required=False, @@ -153,53 +130,51 @@ def fetch_workflow_variable(variable_name: str) -> VariableEntity: ) ) else: - raise ValueError('variable not found') + raise ValueError("variable not found") return WorkflowTool( identity=ToolIdentity( - author=user.name if user else '', + author=user.name if user else "", name=db_provider.name, - label=I18nObject( - en_US=db_provider.label, - zh_Hans=db_provider.label - ), + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), provider=self.provider_id, icon=db_provider.icon, ), description=ToolDescription( - human=I18nObject( - en_US=db_provider.description, - zh_Hans=db_provider.description - ), + human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), llm=db_provider.description, ), parameters=workflow_tool_parameters, is_team_authorization=True, workflow_app_id=app.id, workflow_entities={ - 'app': app, - 'workflow': workflow, + "app": app, + "workflow": workflow, }, version=db_provider.version, workflow_call_depth=0, - label=db_provider.label + label=db_provider.label, ) def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, - ).first() + db_providers: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ) + .first() + ) if not db_providers: return [] @@ -210,10 +185,10 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 38f10032e2282b..bf336b48f304e8 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -12,8 +12,8 @@ from core.tools.tool.tool import Tool API_TOOL_DEFAULT_TIMEOUT = ( - int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')), - int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) + int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), + int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")), ) @@ -24,31 +24,32 @@ class ApiTool(Tool): Api tool """ - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, - runtime=Tool.Runtime(**runtime) + runtime=Tool.Runtime(**runtime), ) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], - format_only: bool = False) -> str: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str: """ - validate the credentials for Api tool + validate the credentials for Api tool """ - # assemble validate request and request parameters + # assemble validate request and request parameters headers = self.assembling_request(parameters) if format_only: - return '' + return "" response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response @@ -61,30 +62,30 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} credentials = self.runtime.credentials or {} - if 'auth_type' not in credentials: - raise ToolProviderCredentialValidationError('Missing auth_type') + if "auth_type" not in credentials: + raise ToolProviderCredentialValidationError("Missing auth_type") - if credentials['auth_type'] == 'api_key': - api_key_header = 'api_key' + if credentials["auth_type"] == "api_key": + api_key_header = "api_key" - if 'api_key_header' in credentials: - api_key_header = credentials['api_key_header'] + if "api_key_header" in credentials: + api_key_header = credentials["api_key_header"] - if 'api_key_value' not in credentials: - raise ToolProviderCredentialValidationError('Missing api_key_value') - elif not isinstance(credentials['api_key_value'], str): - raise ToolProviderCredentialValidationError('api_key_value must be a string') + if "api_key_value" not in credentials: + raise ToolProviderCredentialValidationError("Missing api_key_value") + elif not isinstance(credentials["api_key_value"], str): + raise ToolProviderCredentialValidationError("api_key_value must be a string") - if 'api_key_header_prefix' in credentials: - api_key_header_prefix = credentials['api_key_header_prefix'] - if api_key_header_prefix == 'basic' and credentials['api_key_value']: - credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}' - elif api_key_header_prefix == 'bearer' and credentials['api_key_value']: - credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' - elif api_key_header_prefix == 'custom': + if "api_key_header_prefix" in credentials: + api_key_header_prefix = credentials["api_key_header_prefix"] + if api_key_header_prefix == "basic" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == "bearer" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == "custom": pass - headers[api_key_header] = credentials['api_key_value'] + headers[api_key_header] = credentials["api_key_value"] needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] for parameter in needed_parameters: @@ -98,13 +99,13 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: def validate_and_parse_response(self, response: httpx.Response) -> str: """ - validate the response + validate the response """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return 'Empty response from the tool, please check your parameters and try again.' + return "Empty response from the tool, please check your parameters and try again." try: response = response.json() try: @@ -114,21 +115,22 @@ def validate_and_parse_response(self, response: httpx.Response) -> str: except Exception as e: return response.text else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") @staticmethod def get_parameter_value(parameter, parameters): - if parameter['name'] in parameters: - return parameters[parameter['name']] - elif parameter.get('required', False): + if parameter["name"] in parameters: + return parameters[parameter["name"]] + elif parameter.get("required", False): raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}") else: - return (parameter.get('schema', {}) or {}).get('default', '') + return (parameter.get("schema", {}) or {}).get("default", "") - def do_http_request(self, url: str, method: str, headers: dict[str, Any], - parameters: dict[str, Any]) -> httpx.Response: + def do_http_request( + self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] + ) -> httpx.Response: """ - do http request depending on api bundle + do http request depending on api bundle """ method = method.lower() @@ -138,29 +140,30 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any], cookies = {} # check parameters - for parameter in self.api_bundle.openapi.get('parameters', []): + for parameter in self.api_bundle.openapi.get("parameters", []): value = self.get_parameter_value(parameter, parameters) - if parameter['in'] == 'path': - path_params[parameter['name']] = value + if parameter["in"] == "path": + path_params[parameter["name"]] = value - elif parameter['in'] == 'query': - if value !='': params[parameter['name']] = value + elif parameter["in"] == "query": + if value != "": + params[parameter["name"]] = value - elif parameter['in'] == 'cookie': - cookies[parameter['name']] = value + elif parameter["in"] == "cookie": + cookies[parameter["name"]] = value - elif parameter['in'] == 'header': - headers[parameter['name']] = value + elif parameter["in"] == "header": + headers[parameter["name"]] = value # check if there is a request body and handle it - if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None: + if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: # handle json request body - if 'content' in self.api_bundle.openapi['requestBody']: - for content_type in self.api_bundle.openapi['requestBody']['content']: - headers['Content-Type'] = content_type - body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "content" in self.api_bundle.openapi["requestBody"]: + for content_type in self.api_bundle.openapi["requestBody"]["content"]: + headers["Content-Type"] = content_type + body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): if name in parameters: # convert type @@ -169,63 +172,71 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any], raise ToolParameterValidationError( f"Missing required parameter {name} in operation {self.api_bundle.operation_id}" ) - elif 'default' in property: - body[name] = property['default'] + elif "default" in property: + body[name] = property["default"] else: body[name] = None break # replace path parameters for name, value in path_params.items(): - url = url.replace(f'{{{name}}}', f'{value}') + url = url.replace(f"{{{name}}}", f"{value}") # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored - if 'Content-Type' in headers: - if headers['Content-Type'] == 'application/json': + if "Content-Type" in headers: + if headers["Content-Type"] == "application/json": body = json.dumps(body) - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + elif headers["Content-Type"] == "application/x-www-form-urlencoded": body = urlencode(body) else: body = body - if method in ('get', 'head', 'post', 'put', 'delete', 'patch'): - response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body, - timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) + if method in ("get", "head", "post", "put", "delete", "patch"): + response = getattr(ssrf_proxy, method)( + url, + params=params, + headers=headers, + cookies=cookies, + data=body, + timeout=API_TOOL_DEFAULT_TIMEOUT, + follow_redirects=True, + ) return response else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") - def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], - max_recursive=10) -> Any: + def _convert_body_property_any_of( + self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 + ) -> Any: if max_recursive <= 0: raise Exception("Max recursion depth reached") for option in any_of or []: try: - if 'type' in option: + if "type" in option: # Attempt to convert the value based on the type. - if option['type'] == 'integer' or option['type'] == 'int': + if option["type"] == "integer" or option["type"] == "int": return int(value) - elif option['type'] == 'number': - if '.' in str(value): + elif option["type"] == "number": + if "." in str(value): return float(value) else: return int(value) - elif option['type'] == 'string': + elif option["type"] == "string": return str(value) - elif option['type'] == 'boolean': - if str(value).lower() in ['true', '1']: + elif option["type"] == "boolean": + if str(value).lower() in ["true", "1"]: return True - elif str(value).lower() in ['false', '0']: + elif str(value).lower() in ["false", "0"]: return False else: continue # Not a boolean, try next option - elif option['type'] == 'null' and not value: + elif option["type"] == "null" and not value: return None else: continue # Unsupported type, try next option - elif 'anyOf' in option and isinstance(option['anyOf'], list): + elif "anyOf" in option and isinstance(option["anyOf"], list): # Recursive call to handle nested anyOf - return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1) + return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1) except ValueError: continue # Conversion failed, try next option # If no option succeeded, you might want to return the value as is or raise an error @@ -233,23 +244,23 @@ def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, an def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any: try: - if 'type' in property: - if property['type'] == 'integer' or property['type'] == 'int': + if "type" in property: + if property["type"] == "integer" or property["type"] == "int": return int(value) - elif property['type'] == 'number': + elif property["type"] == "number": # check if it is a float - if '.' in str(value): + if "." in str(value): return float(value) else: return int(value) - elif property['type'] == 'string': + elif property["type"] == "string": return str(value) - elif property['type'] == 'boolean': + elif property["type"] == "boolean": return bool(value) - elif property['type'] == 'null': + elif property["type"] == "null": if value is None: return None - elif property['type'] == 'object' or property['type'] == 'array': + elif property["type"] == "object" or property["type"] == "array": if isinstance(value, str): try: # an array str like '[1,2]' also can convert to list [1,2] through json.loads @@ -264,8 +275,8 @@ def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> A return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") - elif 'anyOf' in property and isinstance(property['anyOf'], list): - return self._convert_body_property_any_of(property, value, property['anyOf']) + elif "anyOf" in property and isinstance(property["anyOf"], list): + return self._convert_body_property_any_of(property, value, property["anyOf"]) except ValueError as e: return value diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index ad7a88838b1a24..8edaf7c0e6a5f5 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.entities.tool_entities import ToolProviderType @@ -16,40 +15,38 @@ class BuiltinTool(Tool): """ - Builtin tool + Builtin tool - :param meta: the meta data of a tool call processing + :param meta: the meta data of a tool call processing """ - def invoke_model( - self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str] - ) -> LLMResult: + def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: """ - invoke model + invoke model - :param model_config: the model config - :param prompt_messages: the prompt messages - :param stop: the stop words - :return: the model result + :param model_config: the model config + :param prompt_messages: the prompt messages + :param stop: the stop words + :return: the model result """ # invoke model return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id, - tool_type='builtin', + tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, ) - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.BUILT_IN - + def get_max_tokens(self) -> int: """ - get max tokens + get max tokens - :param model_config: the model config - :return: the max tokens + :param model_config: the model config + :return: the max tokens """ return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id, @@ -57,39 +54,34 @@ def get_max_tokens(self) -> int: def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: """ - get prompt tokens + get prompt tokens - :param prompt_messages: the prompt messages - :return: the tokens + :param prompt_messages: the prompt messages + :return: the tokens """ - return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id, - prompt_messages=prompt_messages - ) + return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=content) - ]) < max_tokens * 0.6: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6: return content - + def get_prompt_tokens(content: str) -> int: - return self.get_prompt_tokens(prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ]) - + return self.get_prompt_tokens( + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)] + ) + def summarize(content: str) -> str: - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)], + stop=[], + ) return summary.message.content - lines = content.split('\n') + lines = content.split("\n") new_lines = [] # split long line into multiple lines for i in range(len(lines)): @@ -100,8 +92,8 @@ def summarize(content: str) -> str: new_lines.append(line) elif get_prompt_tokens(line) > max_tokens * 0.7: while get_prompt_tokens(line) > max_tokens * 0.7: - new_lines.append(line[:int(max_tokens * 0.5)]) - line = line[int(max_tokens * 0.5):] + new_lines.append(line[: int(max_tokens * 0.5)]) + line = line[int(max_tokens * 0.5) :] new_lines.append(line) else: new_lines.append(line) @@ -125,17 +117,15 @@ def summarize(content: str) -> str: summary = summarize(message) summaries.append(summary) - result = '\n'.join(summaries) + result = "\n".join(summaries) - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=result) - ]) > max_tokens * 0.7: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7: return self.summary(user_id=user_id, content=result) - + return result - + def get_url(self, url: str, user_agent: str = None) -> str: """ - get url + get url """ - return get_url(url, user_agent=user_agent) \ No newline at end of file + return get_url(url, user_agent=user_agent) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index d6ecc9257bc295..e76af6fe700332 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -14,14 +14,11 @@ from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -31,6 +28,7 @@ class DatasetMultiRetrieverToolInput(BaseModel): class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying multi dataset.""" + name: str = "dataset_" args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput description: str = "dataset multi retriever and rerank. " @@ -38,27 +36,26 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_provider_name: str reranking_model_name: str - @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", - tenant_id=tenant_id, - dataset_ids=dataset_ids, - **kwargs + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs ) def _run(self, query: str) -> str: threads = [] all_documents = [] for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'all_documents': all_documents, - 'hit_callbacks': self.hit_callbacks - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -69,7 +66,7 @@ def _run(self, query: str) -> str: tenant_id=self.tenant_id, provider=self.reranking_provider_name, model_type=ModelType.RERANK, - model=self.reranking_model_name + model=self.reranking_model_name, ) rerank_runner = RerankModelRunner(rerank_model_instance) @@ -80,62 +77,61 @@ def _run(self, query: str) -> str: document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 @@ -144,13 +140,18 @@ def _run(self, query: str) -> str: return str("\n".join(document_context_list)) - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler]): + def _retriever( + self, + flask_app: Flask, + dataset_id: str, + query: str, + all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler], + ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + ) if not dataset: return [] @@ -163,27 +164,29 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) if documents: all_documents.extend(documents) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) - - all_documents.extend(documents) \ No newline at end of file + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) + + all_documents.extend(documents) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index 62e97a02306e58..dad8c773579099 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -9,6 +9,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): """Tool for querying a Dataset.""" + name: str = "dataset" description: str = "use this to retrieve a dataset. " tenant_id: str diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 220e4baa8555ea..f61458278eaf54 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -8,15 +7,12 @@ from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'reranking_mode': 'reranking_model', - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "reranking_mode": "reranking_model", + "top_k": 2, + "score_threshold_enabled": False, } @@ -26,35 +22,34 @@ class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" + name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - **kwargs + **kwargs, ) def _run(self, query: str) -> str: - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == self.dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) if not dataset: - return '' + return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) @@ -63,27 +58,29 @@ def _run(self, query: str) -> str: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) else: documents = [] @@ -92,25 +89,26 @@ def _run(self, query: str) -> str: document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: @@ -118,36 +116,36 @@ def _run(self, query: str) -> str: resource_number = 1 for segment in sorted_segments: context = {} - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) - + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) \ No newline at end of file + return str("\n".join(document_context_list)) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index b5698ad23047ea..3c9295c493c470 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -20,13 +20,14 @@ class DatasetRetrieverTool(Tool): retrieval_tool: DatasetRetrieverBaseTool @staticmethod - def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: """ get dataset tool """ @@ -48,7 +49,7 @@ def get_dataset_tools(tenant_id: str, retrieve_config=retrieve_config, return_resource=return_resource, invoke_from=invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -58,13 +59,13 @@ def get_dataset_tools(tenant_id: str, for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( retrieval_tool=retrieval_tool, - identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')), + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), parameters=[], is_team_authorization=True, - description=ToolDescription( - human=I18nObject(en_US='', zh_Hans=''), - llm=retrieval_tool.description), - runtime=DatasetRetrieverTool.Runtime() + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + runtime=DatasetRetrieverTool.Runtime(), ) tools.append(tool) @@ -73,16 +74,18 @@ def get_dataset_tools(tenant_id: str, def get_runtime_parameters(self) -> list[ToolParameter]: return [ - ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + ), ] - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL @@ -90,9 +93,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe """ invoke dataset retriever tool """ - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - return self.create_text_message(text='please input query') + return self.create_text_message(text="please input query") # invoke dataset retriever tool result = self.retrieval_tool._run(query=query) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index d990131b5fbbfd..ac3dc84db4e1f7 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -35,15 +35,16 @@ class Tool(BaseModel, ABC): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - @field_validator('parameters', mode='before') + @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: return v or [] class Runtime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing """ + def __init__(self, **data: Any): super().__init__(**data) if not self.runtime_parameters: @@ -63,14 +64,14 @@ def __init__(self, **data: Any): super().__init__(**data) class VARIABLE_KEY(Enum): - IMAGE = 'image' + IMAGE = "image" - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, @@ -82,22 +83,22 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ def load_variables(self, variables: ToolRuntimeVariablePool): """ - load variables from database + load variables from database - :param conversation_id: the conversation id + :param conversation_id: the conversation id """ self.variables = variables def set_image_variable(self, variable_name: str, image_key: str) -> None: """ - set an image variable + set an image variable """ if not self.variables: return @@ -106,7 +107,7 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: def set_text_variable(self, variable_name: str, text: str) -> None: """ - set a text variable + set a text variable """ if not self.variables: return @@ -115,10 +116,10 @@ def set_text_variable(self, variable_name: str, text: str) -> None: def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ - get a variable + get a variable - :param name: the name of the variable - :return: the variable + :param name: the name of the variable + :return: the variable """ if not self.variables: return None @@ -134,9 +135,9 @@ def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ - get the default image variable + get the default image variable - :return: the image variable + :return: the image variable """ if not self.variables: return None @@ -145,10 +146,10 @@ def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ - get a variable file + get a variable file - :param name: the name of the variable - :return: the variable file + :param name: the name of the variable + :return: the variable file """ variable = self.get_variable(name) if not variable: @@ -167,9 +168,9 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: def list_variables(self) -> list[ToolRuntimeVariable]: """ - list all variables + list all variables - :return: the variables + :return: the variables """ if not self.variables: return [] @@ -178,9 +179,9 @@ def list_variables(self) -> list[ToolRuntimeVariable]: def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ - list all image variables + list all image variables - :return: the image variables + :return: the image variables """ if not self.variables: return [] @@ -220,38 +221,42 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> result = deepcopy(tool_parameters) for parameter in self.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type) + result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( + tool_parameters[parameter.name], parameter.type + ) return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ - validate the credentials + validate the credentials - :param credentials: the credentials - :param parameters: the parameters + :param credentials: the credentials + :param parameters: the parameters """ pass def get_runtime_parameters(self) -> list[ToolParameter]: """ - get the runtime parameters + get the runtime parameters - interface for developer to dynamic change the parameters of a tool depends on the variables pool + interface for developer to dynamic change the parameters of a tool depends on the variables pool - :return: the runtime parameters + :return: the runtime parameters """ return self.parameters or [] def get_all_runtime_parameters(self) -> list[ToolParameter]: """ - get all runtime parameters + get all runtime parameters - :return: all runtime parameters + :return: all runtime parameters """ parameters = self.parameters or [] parameters = parameters.copy() @@ -281,67 +286,49 @@ def get_all_runtime_parameters(self) -> list[ToolParameter]: return parameters - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: """ - create an image message + create an image message - :param image: the url of the image - :return: the image message + :param image: the url of the image + :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, - message='', - meta={ - 'file_var': file_var - }, - save_as='') + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as="" + ) - def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ - create a link message + create a link message - :param link: the url of the link - :return: the link message + :param link: the url of the link + :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) - def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: """ - create a text message + create a text message - :param text: the text - :return: the text message + :param text: the text + :return: the text message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=text, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) - def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = "") -> ToolInvokeMessage: """ - create a blob message + create a blob message - :param blob: the blob - :return: the blob message + :param blob: the blob + :return: the blob message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=blob, meta=meta, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ - create a json message + create a json message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, - message=object - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 15e915628efbfa..ad0c7fc631fd08 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) + class WorkflowTool(Tool): workflow_app_id: str version: str @@ -25,11 +26,12 @@ class WorkflowTool(Tool): """ Workflow tool. """ + def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ return ToolProviderType.WORKFLOW @@ -37,7 +39,7 @@ def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke the tool + invoke the tool """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) @@ -46,33 +48,31 @@ def _invoke( tool_parameters, files = self._transform_args(tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() result = generator.generate( - app_model=app, - workflow=workflow, - user=self._get_user(user_id), - args={ - 'inputs': tool_parameters, - 'files': files - }, + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, - workflow_thread_pool_id=self.thread_pool_id + workflow_thread_pool_id=self.thread_pool_id, ) - data = result.get('data', {}) + data = result.get("data", {}) + + if data.get("error"): + raise Exception(data.get("error")) - if data.get('error'): - raise Exception(data.get('error')) - result = [] - outputs = data.get('outputs', {}) + outputs = data.get("outputs", {}) outputs, files = self._extract_files(outputs) for file in files: result.append(self.create_file_var_message(file)) - + result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs)) @@ -80,7 +80,7 @@ def _invoke( def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ - get the user by user id + get the user by user id """ user = db.session.query(EndUser).filter(EndUser.id == user_id).first() @@ -88,16 +88,16 @@ def _get_user(self, user_id: str) -> Union[EndUser, Account]: user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=deepcopy(self.identity), @@ -108,45 +108,44 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, - label=self.label + label=self.label, ) - + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ - get the workflow by app id and version + get the workflow by app id and version """ if not version: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version != 'draft' - ).order_by(Workflow.created_at.desc()).first() + workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .order_by(Workflow.created_at.desc()) + .first() + ) else: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version == version - ).first() + workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: - raise ValueError('workflow not found or not published') + raise ValueError("workflow not found or not published") return workflow - + def _get_app(self, app_id: str) -> App: """ - get the app by app id + get the app by app id """ app = db.session.query(App).filter(App.id == app_id).first() if not app: - raise ValueError('app not found') + raise ValueError("app not found") return app - + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ - transform the tool parameters + transform the tool parameters - :param tool_parameters: the tool parameters - :return: tool_parameters, files + :param tool_parameters: the tool parameters + :return: tool_parameters, files """ parameter_rules = self.get_all_runtime_parameters() parameters_result = {} @@ -159,15 +158,15 @@ def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: file_var_list = [FileVar(**f) for f in file] for file_var in file_var_list: file_dict = { - 'transfer_method': file_var.transfer_method.value, - 'type': file_var.type.value, + "transfer_method": file_var.transfer_method.value, + "type": file_var.type.value, } if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict['tool_file_id'] = file_var.related_id + file_dict["tool_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict['upload_file_id'] = file_var.related_id + file_dict["upload_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict['url'] = file_var.preview_url + file_dict["url"] = file_var.preview_url files.append(file_dict) except Exception as e: @@ -176,13 +175,13 @@ def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: parameters_result[parameter.name] = tool_parameters.get(parameter.name) return parameters_result, files - + def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: """ - extract files from the result + extract files from the result - :param result: the result - :return: the result, files + :param result: the result + :return: the result, files """ files = [] result = {} @@ -190,7 +189,7 @@ def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: if isinstance(value, list): has_file = False for item in value: - if isinstance(item, dict) and item.get('__variant') == 'FileVar': + if isinstance(item, dict) and item.get("__variant") == "FileVar": try: files.append(FileVar(**item)) has_file = True @@ -201,4 +200,4 @@ def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: result[key] = value - return result, files \ No newline at end of file + return result, files diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 6c0e906628ab82..9a6a49d8f443b4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -33,12 +33,17 @@ class ToolEngine: """ Tool runtime engine take care of the tool executions. """ + @staticmethod def agent_invoke( - tool: Tool, tool_parameters: Union[str, dict], - user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, + tool: Tool, + tool_parameters: Union[str, dict], + user_id: str, + tenant_id: str, + message: Message, + invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -47,40 +52,30 @@ def agent_invoke( if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.get_runtime_parameters() or [] + parameter + for parameter in tool.get_runtime_parameters() or [] if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: - tool_parameters = { - parameters[0].name: tool_parameters - } + tool_parameters = {parameters[0].name: tool_parameters} else: raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: # hit the callback handler - agent_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) response = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=response, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=message.conversation_id + messages=response, user_id=user_id, tenant_id=tenant_id, conversation_id=message.conversation_id ) # extract binary data from tool invoke message binary_files = ToolEngine._extract_tool_response_binary(response) # create message file message_files = ToolEngine._create_message_files( - tool_messages=binary_files, - agent_message=message, - invoke_from=invoke_from, - user_id=user_id + tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id ) plain_text = ToolEngine._convert_tool_response_to_str(response) @@ -91,7 +86,7 @@ def agent_invoke( tool_inputs=tool_parameters, tool_outputs=plain_text, message_id=message.id, - trace_manager=trace_manager + trace_manager=trace_manager, ) # transform tool invoke message to get LLM friendly message @@ -99,14 +94,10 @@ def agent_invoke( except ToolProviderCredentialValidationError as e: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: + except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.identity.name}" agent_tool_callback.on_tool_error(e) - except ( - ToolParameterValidationError - ) as e: + except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" agent_tool_callback.on_tool_error(e) except ToolInvokeError as e: @@ -124,21 +115,20 @@ def agent_invoke( return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - thread_pool_id: Optional[str] = None - ) -> list[ToolInvokeMessage]: + def workflow_invoke( + tool: Tool, + tool_parameters: Mapping[str, Any], + user_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int, + thread_pool_id: Optional[str] = None, + ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. """ try: # hit the callback handler - workflow_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 @@ -159,21 +149,24 @@ def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + @staticmethod - def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: """ Invoke the tool with the given arguments. """ started_at = datetime.now(timezone.utc) - meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ - 'tool_name': tool.identity.name, - 'tool_provider': tool.identity.provider, - 'tool_provider_type': tool.tool_provider_type().value, - 'tool_parameters': deepcopy(tool.runtime.runtime_parameters), - 'tool_icon': tool.identity.icon - }) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.identity.name, + "tool_provider": tool.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.identity.icon, + }, + ) try: response = tool.invoke(user_id, tool_parameters) except Exception as e: @@ -184,20 +177,22 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ meta.time_cost = (ended_at - started_at).total_seconds() return meta, response - + @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ - result = '' + result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + elif ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." @@ -205,7 +200,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result += f"tool response: {response.message}." return result - + @staticmethod def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ @@ -214,52 +209,59 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): mimetype = None - if response.meta.get('mime_type'): - mimetype = response.meta.get('mime_type') + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") else: try: url = URL(response.message) extension = url.suffix - guess_type_result, _ = guess_type(f'a{extension}') + guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result except Exception: pass - + if not mimetype: - mimetype = 'image/jpeg' - - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'image/jpeg'), - url=response.message, - save_as=response.save_as, - )) + mimetype = "image/jpeg" + + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "image/jpeg"), + url=response.message, + save_as=response.save_as, + ) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream"), url=response.message, save_as=response.save_as, - )) + ) + ) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and "mime_type" in response.meta: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream") + if response.meta + else "octet/stream", + url=response.message, + save_as=response.save_as, + ) + ) return result - + @staticmethod def _create_message_files( - tool_messages: list[ToolInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str + tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str ) -> list[tuple[Any, str]]: """ Create message file @@ -270,29 +272,29 @@ def _create_message_files( result = [] for message in tool_messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' + file_type = "bin" + if "image" in message.mimetype: + file_type = "image" + elif "video" in message.mimetype: + file_type = "video" + elif "audio" in message.mimetype: + file_type = "audio" + elif "text" in message.mimetype: + file_type = "text" + elif "pdf" in message.mimetype: + file_type = "pdf" + elif "zip" in message.mimetype: + file_type = "archive" # ... message_file = MessageFile( message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', + belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), created_by=user_id, ) @@ -300,11 +302,8 @@ def _create_message_files( db.session.commit() db.session.refresh(message_file) - result.append(( - message_file.id, - message.save_as - )) + result.append((message_file.id, message.save_as)) db.session.close() - return result \ No newline at end of file + return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f9f7c7d78a7f28..ad3b9c7328ca74 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -27,24 +27,24 @@ def sign_file(tool_file_id: str, extension: str) -> str: sign file to get a temporary url """ base_url = dify_config.FILES_URL - file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -62,9 +62,9 @@ def create_file_by_raw( """ create file """ - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, file_binary) tool_file = ToolFile( @@ -90,10 +90,10 @@ def create_file_by_url( response = get(file_url) response.raise_for_status() blob = response.content - mimetype = guess_type(file_url)[0] or 'octet/stream' - extension = guess_extension(mimetype) or '.bin' + mimetype = guess_type(file_url)[0] or "octet/stream" + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, blob) tool_file = ToolFile( @@ -166,13 +166,12 @@ def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] else: tool_file_id = None - tool_file: ToolFile = ( db.session.query(ToolFile) .filter( @@ -216,4 +215,4 @@ def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generat # init tool_file_parser from core.file.tool_file_parser import tool_file_manager -tool_file_manager['manager'] = ToolFileManager +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 97788a7a07dfb0..2a5a2944ef8471 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -15,7 +15,7 @@ def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: """ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] return list(set(tool_labels)) - + @classmethod def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): """ @@ -26,20 +26,20 @@ def update_tool_labels(cls, controller: ToolProviderController, labels: list[str if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id == provider_id - ).delete() + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: - db.session.add(ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - )) + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + ) + ) db.session.commit() @@ -53,12 +53,16 @@ def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) return [label.label_name for label in labels] @@ -75,22 +79,20 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ """ if not tool_providers: return {} - + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError('Unsupported tool type') - + raise ValueError("Unsupported tool type") + provider_ids = [controller.provider_id for controller in tool_providers] - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id.in_(provider_ids) - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) - tool_labels = { - label.tool_id: [] for label in labels - } + tool_labels = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels \ No newline at end of file + return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4778d79ed9e5c9..a3303797e16dbb 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -41,29 +41,29 @@ class ToolManager: @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ - get the builtin provider + get the builtin provider - :param provider: the name of the provider - :return: the provider + :param provider: the name of the provider + :return: the provider """ if len(cls._builtin_providers) == 0: # init the builtin providers cls.load_builtin_providers_cache() if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f'builtin provider {provider} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") return cls._builtin_providers[provider] @classmethod def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ - get the builtin tool + get the builtin tool - :param provider: the name of the provider - :param tool_name: the name of the tool + :param provider: the name of the provider + :param tool_name: the name of the tool - :return: the provider, the tool + :return: the provider, the tool """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) @@ -71,67 +71,76 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: return tool @classmethod - def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool( + cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool + get the tool - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": return cls.get_builtin_tool(provider_id, tool_name) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) return api_provider.get_tool(tool_name) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def get_tool_runtime(cls, provider_type: str, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool_runtime( + cls, + provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool runtime + get the tool runtime - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) # get credentials - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_id, + ) + .first() + ) if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials @@ -140,17 +149,19 @@ def get_tool_runtime(cls, provider_type: str, decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'runtime_parameters': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "runtime_parameters": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) @@ -158,40 +169,43 @@ def get_tool_runtime(cls, provider_type: str, tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'workflow': - workflow_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "workflow": + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - controller = ToolTransformService.workflow_provider_to_controller( - db_provider=workflow_provider - ) + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: """ - init runtime parameter + init runtime parameter """ parameter_value = parameters.get(parameter_rule.name) if not parameter_value and parameter_value != 0: @@ -205,14 +219,17 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" + ) return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_agent_tool_runtime( + cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER + ) -> Tool: """ - get the agent tool runtime + get the agent tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, @@ -220,7 +237,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo tool_name=agent_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT + tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -240,7 +257,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, - identity_id=f'AGENT.{app_id}' + identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) @@ -248,9 +265,16 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: """ - get the workflow tool runtime + get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, @@ -258,7 +282,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo tool_name=workflow_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW + tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -275,7 +299,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, - identity_id=f'WORKFLOW.{app_id}.{node_id}' + identity_id=f"WORKFLOW.{app_id}.{node_id}", ) if runtime_parameters: @@ -287,24 +311,30 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo @classmethod def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ - get the absolute path of the icon of the builtin provider + get the absolute path of the icon of the builtin provider - :param provider: the name of the provider + :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon + :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', - provider_controller.identity.icon) + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider, + "_assets", + provider_controller.identity.icon, + ) # check if the icon exists if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or 'application/octet-stream' + mime_type = mime_type or "application/octet-stream" return absolute_path, mime_type @@ -325,23 +355,25 @@ def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None @classmethod def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ - list all the builtin providers + list all the builtin providers """ - for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider.startswith('__'): + for provider in listdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")): + if provider.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): - if provider.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin", provider)): + if provider.startswith("__"): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), - parent_type=BuiltinToolProviderController) + module_name=f"core.tools.provider.builtin.{provider}.{provider}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "provider", "builtin", provider, f"{provider}.py" + ), + parent_type=BuiltinToolProviderController, + ) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): @@ -349,7 +381,7 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non yield provider except Exception as e: - logger.error(f'load builtin provider {provider} error: {e}') + logger.error(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -367,11 +399,11 @@ def clear_builtin_providers_cache(cls): @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ - get the tool label + get the tool label - :param tool_name: the name of the tool + :param tool_name: the name of the tool - :return: the label of the tool + :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers @@ -383,75 +415,78 @@ def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: + def user_list_providers( + cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral + ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} filters = [] if not typ: - filters.extend(['builtin', 'api', 'workflow']) + filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - if 'builtin' in filters: - + if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers() # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None + (x for x in db_builtin_providers if x.provider == provider), None ) # append builtin providers for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, - data=provider, - name_func=lambda x: x.identity.name + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name, ): continue user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + decrypt_credentials=False, ) result_providers[provider.identity.name] = user_provider # get db api providers - if 'api' in filters: - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ - filter(ApiToolProvider.tenant_id == tenant_id).all() + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) - api_provider_controllers = [{ - 'provider': provider, - 'controller': ToolTransformService.api_provider_to_controller(provider) - } for provider in db_api_providers] + api_provider_controllers = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] # get labels - labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller['controller'], - db_provider=api_provider_controller['provider'], + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], decrypt_credentials=False, - labels=labels.get(api_provider_controller['controller'].provider_id, []) + labels=labels.get(api_provider_controller["controller"].provider_id, []), ) - result_providers[f'api_provider.{user_provider.name}'] = user_provider + result_providers[f"api_provider.{user_provider.name}"] = user_provider - if 'workflow' in filters: + if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ - filter(WorkflowToolProvider.tenant_id == tenant_id).all() + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) workflow_provider_controllers = [] for provider in workflow_providers: @@ -470,32 +505,36 @@ def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProvider provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) - result_providers[f'workflow_provider.{user_provider.name}'] = user_provider + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod - def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ - get the api provider + get the api provider - :param provider_name: the name of the provider + :param provider_name: the name of the provider - :return: the provider controller, the credentials + :return: the provider controller, the credentials """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'api provider {provider_id} not found') + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) @@ -504,18 +543,22 @@ def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ - get api provider + get api provider """ """ get tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f"you have not added provider {provider}") try: credentials = json.loads(provider.credentials_str) or {} @@ -524,7 +567,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -535,62 +578,62 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: try: icon = json.loads(provider.icon) except: - icon = { - "background": "#252525", - "content": "\ud83d\ude01" - } + icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder({ - 'schema_type': provider.schema_type, - 'schema': provider.schema, - 'tools': provider.tools, - 'icon': icon, - 'description': provider.description, - 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer, - 'labels': labels, - }) + return jsonable_encoder( + { + "schema_type": provider.schema_type, + "schema": provider.schema, + "tools": provider.tools, + "icon": icon, + "description": provider.description, + "credentials": masked_credentials, + "privacy_policy": provider.privacy_policy, + "custom_disclaimer": provider.custom_disclaimer, + "labels": labels, + } + ) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: """ - get the tool icon + get the tool icon - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: """ provider_type = provider_type provider_id = provider_id - if provider_type == 'builtin': - return (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon") - elif provider_type == 'api': + if provider_type == "builtin": + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.id == provider_id - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) return json.loads(provider.icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - elif provider_type == 'workflow': - provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return {"background": "#252525", "content": "\ud83d\ude01"} + elif provider_type == "workflow": + provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return json.loads(provider.icon) else: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b213879e960b14..83600d21c13dc2 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -56,12 +56,13 @@ def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]: if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: - credentials[field_name] = \ - credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) + \ - credentials[field_name][-2:] + credentials[field_name] = ( + credentials[field_name][:2] + + "*" * (len(credentials[field_name]) - 4) + + credentials[field_name][-2:] + ) else: - credentials[field_name] = '*' * len(credentials[field_name]) + credentials[field_name] = "*" * len(credentials[field_name]) return credentials @@ -72,9 +73,9 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() if cached_credentials: @@ -95,16 +96,18 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager """ + tenant_id: str tool_runtime: Tool provider_name: str @@ -152,15 +155,19 @@ def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: current_parameters = self._merge_parameters() for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: if len(parameters[parameter.name]) > 6: - parameters[parameter.name] = \ - parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) + \ - parameters[parameter.name][-2:] + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) else: - parameters[parameter.name] = '*' * len(parameters[parameter.name]) + parameters[parameter.name] = "*" * len(parameters[parameter.name]) return parameters @@ -176,7 +183,10 @@ def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: parameters = self._deep_copy(parameters) for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted @@ -191,10 +201,10 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cached_parameters = cache.get() if cached_parameters: @@ -205,7 +215,10 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: has_secret_input = False for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: try: has_secret_input = True @@ -221,9 +234,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cache.delete() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index e6b288868f47b0..7bb026a383a616 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -17,8 +17,9 @@ def tenant_access_token(self): redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) return res.get("tenant_access_token") - def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, - params: dict = None): + def _send_request( + self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None + ): headers = { "Content-Type": "application/json", "user-agent": "Dify", @@ -42,10 +43,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: } """ url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} res = self._send_request(url, require_token=False, payload=payload) return res @@ -76,11 +74,7 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: def write_document(self, document_id: str, content: str, position: str = "start") -> dict: url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" - payload = { - "document_id": document_id, - "content": content, - "position": position - } + payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) return res.get("data") diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 23e7c0c2438367..c4983ebc65ea8d 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -10,10 +10,9 @@ class ToolFileMessageTransformer: @classmethod - def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], - user_id: str, - tenant_id: str, - conversation_id: str) -> list[ToolInvokeMessage]: + def transform_tool_invoke_messages( + cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str + ) -> list[ToolInvokeMessage]: """ Transform tool message and handle file download """ @@ -28,78 +27,88 @@ def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_url=message.message + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) except Exception as e: logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + ) + ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') + mimetype = message.meta.get("mime_type", "octet/stream") # if message is str, encode it to bytes if isinstance(message.message, str): - message.message = message.message.encode('utf-8') + message.message = message.message.encode("utf-8") file = ToolFileManager.create_file_by_raw( - user_id=user_id, tenant_id=tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, - mimetype=mimetype + mimetype=mimetype, ) url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + if "image" in mimetype: + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var = message.meta.get('file_var') + file_var = message.meta.get("file_var") if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) if file_var.type == FileType.IMAGE: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: result.append(message) diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e8ef478237d6e..4e226810d6ac90 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -1,7 +1,7 @@ """ - For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - Therefore, a model manager is needed to list/invoke/validate models. +Therefore, a model manager is needed to list/invoke/validate models. """ import json @@ -27,52 +27,49 @@ class InvokeModelError(Exception): pass + class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, ) -> int: """ - get max llm context tokens of the model + get max llm context tokens of the model """ model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) if not schema: - raise InvokeModelError('No model schema found') + raise InvokeModelError("No model schema found") max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 - + return max_tokens @staticmethod - def calculate_tokens( - tenant_id: str, - prompt_messages: list[PromptMessage] - ) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: """ - calculate tokens from prompt messages and model parameters + calculate tokens from prompt messages and model parameters """ # get model instance model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + # get tokens tokens = model_instance.get_llm_num_tokens(prompt_messages) @@ -80,9 +77,7 @@ def calculate_tokens( @staticmethod def invoke( - user_id: str, tenant_id: str, - tool_type: str, tool_name: str, - prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context @@ -103,15 +98,16 @@ def invoke( model_manager = ModelManager() # get model instance model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) # get prompt tokens prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { - 'temperature': 0.8, - 'top_p': 0.8, + "temperature": 0.8, + "top_p": 0.8, } # create tool model invoke @@ -123,14 +119,14 @@ def invoke( tool_name=tool_name, model_parameters=json.dumps(model_parameters), prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response='', + model_response="", prompt_tokens=prompt_tokens, answer_tokens=0, answer_unit_price=0, answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", ) db.session.add(tool_model_invoke) @@ -140,20 +136,24 @@ def invoke( response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=[], stop=[], stream=False, user=user_id, callbacks=[] + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: - raise InvokeModelError(f'Invoke rate limit error: {e}') + raise InvokeModelError(f"Invoke rate limit error: {e}") except InvokeBadRequestError as e: - raise InvokeModelError(f'Invoke bad request error: {e}') + raise InvokeModelError(f"Invoke bad request error: {e}") except InvokeConnectionError as e: - raise InvokeModelError(f'Invoke connection error: {e}') + raise InvokeModelError(f"Invoke connection error: {e}") except InvokeAuthorizationError as e: - raise InvokeModelError('Invoke authorization error') + raise InvokeModelError("Invoke authorization error") except InvokeServerUnavailableError as e: - raise InvokeModelError(f'Invoke server unavailable error: {e}') + raise InvokeModelError(f"Invoke server unavailable error: {e}") except Exception as e: - raise InvokeModelError(f'Invoke error: {e}') + raise InvokeModelError(f"Invoke error: {e}") # update tool model invoke tool_model_invoke.model_response = response.message.content diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f711f7c9f3c2e8..654c9acaf934b4 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,3 @@ - import re import uuid from json import dumps as json_dumps @@ -16,54 +15,56 @@ class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} # set description to extra_info - extra_info['description'] = openapi['info'].get('description', '') + extra_info["description"] = openapi["info"].get("description", "") - if len(openapi['servers']) == 0: - raise ToolProviderNotFoundError('No server found in the openapi yaml.') + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") - server_url = openapi['servers'][0]['url'] + server_url = openapi["servers"][0]["url"] # list all interfaces interfaces = [] - for path, path_item in openapi['paths'].items(): - methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace'] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] for method in methods: if method in path_item: - interfaces.append({ - 'path': path, - 'method': method, - 'operation': path_item[method], - }) + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) # get all parameters bundles = [] for interface in interfaces: # convert parameters parameters = [] - if 'parameters' in interface['operation']: - for parameter in interface['operation']['parameters']: + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: tool_parameter = ToolParameter( - name=parameter['name'], - label=I18nObject( - en_US=parameter['name'], - zh_Hans=parameter['name'] - ), + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), human_description=I18nObject( - en_US=parameter.get('description', ''), - zh_Hans=parameter.get('description', '') + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, - required=parameter.get('required', False), + required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, - llm_description=parameter.get('description'), - default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, ) - + # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) if typ: @@ -72,44 +73,40 @@ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning parameters.append(tool_parameter) # create tool bundle # check if there is a request body - if 'requestBody' in interface['operation']: - request_body = interface['operation']['requestBody'] - if 'content' in request_body: - for content_type, content in request_body['content'].items(): + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): # if there is a reference, get the reference and overwrite the content - if 'schema' not in content: + if "schema" not in content: continue - if '$ref' in content['schema']: + if "$ref" in content["schema"]: # get the reference root = openapi - reference = content['schema']['$ref'].split('/')[1:] + reference = content["schema"]["$ref"].split("/")[1:] for ref in reference: root = root[ref] # overwrite the content - interface['operation']['requestBody']['content'][content_type]['schema'] = root + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root # parse body parameters - if 'schema' in interface['operation']['requestBody']['content'][content_type]: - body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): tool = ToolParameter( name=name, - label=I18nObject( - en_US=name, - zh_Hans=name - ), + label=I18nObject(en_US=name, zh_Hans=name), human_description=I18nObject( - en_US=property.get('description', ''), - zh_Hans=property.get('description', '') + en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, required=name in required, form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get('description', ''), - default=property.get('default', None), + llm_description=property.get("description", ""), + default=property.get("default", None), ) # check if there is a type @@ -127,172 +124,176 @@ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning parameters_count[parameter.name] += 1 for name, count in parameters_count.items(): if count > 1: - warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." # check if there is a operation id, use $path_$method as operation id if not - if 'operationId' not in interface['operation']: + if "operationId" not in interface["operation"]: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = interface['path'] - if interface['path'].startswith('/'): - path = interface['path'][1:] + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: path = str(uuid.uuid4()) - - interface['operation']['operationId'] = f'{path}_{interface["method"]}' - - bundles.append(ApiToolBundle( - server_url=server_url + interface['path'], - method=interface['method'], - summary=interface['operation']['description'] if 'description' in interface['operation'] else - interface['operation'].get('summary', None), - operation_id=interface['operation']['operationId'], - parameters=parameters, - author='', - icon=None, - openapi=interface['operation'], - )) + + interface["operation"]["operationId"] = f'{path}_{interface["method"]}' + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) return bundles - + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None - if 'type' in parameter: - typ = parameter['type'] - elif 'schema' in parameter and 'type' in parameter['schema']: - typ = parameter['schema']['type'] - - if typ == 'integer' or typ == 'number': + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ == "integer" or typ == "number": return ToolParameter.ToolParameterType.NUMBER - elif typ == 'boolean': + elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN - elif typ == 'string': + elif typ == "string": return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi yaml to tool bundle + parse openapi yaml to tool bundle - :param yaml: the yaml string - :return: the tool bundle + :param yaml: the yaml string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} openapi: dict = safe_load(yaml) if openapi is None: - raise ToolApiSchemaError('Invalid openapi yaml.') + raise ToolApiSchemaError("Invalid openapi yaml.") return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) - + @staticmethod def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: """ - parse swagger to openapi + parse swagger to openapi - :param swagger: the swagger dict - :return: the openapi dict + :param swagger: the swagger dict + :return: the openapi dict """ # convert swagger to openapi - info = swagger.get('info', { - 'title': 'Swagger', - 'description': 'Swagger', - 'version': '1.0.0' - }) + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) - servers = swagger.get('servers', []) + servers = swagger.get("servers", []) if len(servers) == 0: - raise ToolApiSchemaError('No server found in the swagger yaml.') + raise ToolApiSchemaError("No server found in the swagger yaml.") openapi = { - 'openapi': '3.0.0', - 'info': { - 'title': info.get('title', 'Swagger'), - 'description': info.get('description', 'Swagger'), - 'version': info.get('version', '1.0.0') + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), + "version": info.get("version", "1.0.0"), }, - 'servers': swagger['servers'], - 'paths': {}, - 'components': { - 'schemas': {} - } + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, } # check paths - if 'paths' not in swagger or len(swagger['paths']) == 0: - raise ToolApiSchemaError('No paths found in the swagger yaml.') + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") # convert paths - for path, path_item in swagger['paths'].items(): - openapi['paths'][path] = {} + for path, path_item in swagger["paths"].items(): + openapi["paths"][path] = {} for method, operation in path_item.items(): - if 'operationId' not in operation: - raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.') - - if ('summary' not in operation or len(operation['summary']) == 0) and \ - ('description' not in operation or len(operation['description']) == 0): - warning['missing_summary'] = f'No summary or description found in operation {method} {path}.' - - openapi['paths'][path][method] = { - 'operationId': operation['operationId'], - 'summary': operation.get('summary', ''), - 'description': operation.get('description', ''), - 'parameters': operation.get('parameters', []), - 'responses': operation.get('responses', {}), + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), } - if 'requestBody' in operation: - openapi['paths'][path][method]['requestBody'] = operation['requestBody'] + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger['definitions'].items(): - openapi['components']['schemas'][name] = definition + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi plugin yaml to tool bundle + parse openapi plugin yaml to tool bundle - :param json: the json string - :return: the tool bundle + :param json: the json string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} try: openai_plugin = json_loads(json) - api = openai_plugin['api'] - api_url = api['url'] - api_type = api['type'] + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] except: - raise ToolProviderNotFoundError('Invalid openai plugin json.') - - if api_type != 'openapi': - raise ToolNotSupportedError('Only openapi is supported now.') - + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + # get openapi yaml - response = get(api_url, headers={ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' - }, timeout=5) + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) if response.status_code != 200: - raise ToolProviderNotFoundError('cannot get openapi yaml from url.') - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) - + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: + def auto_parse_to_tool_bundle( + content: str, extra_info: dict = None, warning: dict = None + ) -> tuple[list[ApiToolBundle], str]: """ - auto parse to tool bundle + auto parse to tool bundle - :param content: the content - :return: tools bundle, schema_type + :param content: the content + :return: tools bundle, schema_type """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -301,7 +302,7 @@ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: di loaded_content = None json_error = None yaml_error = None - + try: loaded_content = json_loads(content) except JSONDecodeError as e: @@ -313,34 +314,46 @@ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: di except YAMLError as e: yaml_error = e if loaded_content is None: - raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}" + ) swagger_error = None openapi_error = None openapi_plugin_error = None schema_type = None - + try: - openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning) + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.OPENAPI.value return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - + # openai parse error, fallback to swagger try: - converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning) + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.SWAGGER.value - return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type except ToolApiSchemaError as e: swagger_error = e - + # swagger parse error, fallback to openai plugin try: - openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning) + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e - raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py index 6f88eeaa0a8a98..6f7610651cb5aa 100644 --- a/api/core/tools/utils/tool_parameter_converter.py +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -7,16 +7,18 @@ class ToolParameterConverter: @staticmethod def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: - return 'string' + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + return "string" case ToolParameter.ToolParameterType.BOOLEAN: - return 'boolean' + return "boolean" case ToolParameter.ToolParameterType.NUMBER: - return 'number' + return "number" case _: raise ValueError(f"Unsupported parameter type {parameter_type}") @@ -26,11 +28,13 @@ def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: # convert tool parameter config to correct type try: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): if value is None: - return '' + return "" else: return value if isinstance(value, str) else str(value) @@ -41,9 +45,9 @@ def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: # Allowed YAML boolean value strings: https://yaml.org/type/bool.html # and also '0' for False and '1' for True match value.lower(): - case 'true' | 'yes' | 'y' | '1': + case "true" | "yes" | "y" | "1": return True - case 'false' | 'no' | 'n' | '0': + case "false" | "no" | "n" | "0": return False case _: return bool(value) @@ -53,8 +57,8 @@ def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: case ToolParameter.ToolParameterType.NUMBER: if isinstance(value, int) | isinstance(value, float): return value - elif isinstance(value, str) and value != '': - if '.' in value: + elif isinstance(value, str) and value != "": + if "." in value: return float(value) else: return int(value) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 150941924d761f..3639b5fff79b7f 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -32,7 +32,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor: cursor + max_length] + return text[cursor : cursor + max_length] def get_url(url: str, user_agent: str = None) -> str: @@ -49,15 +49,15 @@ def get_url(url: str, user_agent: str = None) -> str: if response.status_code == 200: # check content-type - content_type = response.headers.get('Content-Type') + content_type = response.headers.get("Content-Type") if content_type: - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() else: - content_disposition = response.headers.get('Content-Disposition', '') + content_disposition = response.headers.get("Content-Disposition", "") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - extension = re.search(r'\.(\w+)$', filename) + extension = re.search(r"\.(\w+)$", filename) if extension: main_content_type = mimetypes.guess_type(filename)[0] @@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str: # Detect encoding using chardet detected_encoding = chardet.detect(response.content) - encoding = detected_encoding['encoding'] + encoding = detected_encoding["encoding"] if encoding: try: content = response.content.decode(encoding) @@ -89,29 +89,29 @@ def get_url(url: str, user_agent: str = None) -> str: a = extract_using_readabilipy(content) - if not a['plain_text'] or not a['plain_text'].strip(): - return '' + if not a["plain_text"] or not a["plain_text"].strip(): + return "" res = FULL_TEMPLATE.format( - title=a['title'], - authors=a['byline'], - publish_date=a['date'], + title=a["title"], + authors=a["byline"], + publish_date=a["date"], top_image="", - text=a['plain_text'] if a['plain_text'] else "", + text=a["plain_text"] if a["plain_text"] else "", ) return res def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: f_html.write(html) f_html.close() html_path = f_html.name # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") with chdir(jsdir): subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) @@ -129,7 +129,7 @@ def extract_using_readabilipy(html): "date": None, "content": None, "plain_content": None, - "plain_text": None + "plain_text": None, } # Populate article fields from readability fields where present if input_json: @@ -145,7 +145,7 @@ def extract_using_readabilipy(html): article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) if input_json.get("textContent"): article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) return article_json @@ -158,6 +158,7 @@ def find_module_path(module_name): return None + @contextmanager def chdir(path): """Change directory in context and return to original on exit""" @@ -172,12 +173,14 @@ def chdir(path): def extract_text_blocks_as_plain_text(paragraph_html): # Load article as DOM - soup = BeautifulSoup(paragraph_html, 'html.parser') + soup = BeautifulSoup(paragraph_html, "html.parser") # Select all lists - list_elements = soup.find_all(['ul', 'ol']) + list_elements = soup.find_all(["ul", "ol"]) # Prefix text in all list items with "* " and make lists paragraphs for list_element in list_elements: - plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) list_element.string = plain_items list_element.name = "p" # Select all text blocks @@ -204,7 +207,7 @@ def plain_text_leaf_node(element): def plain_content(readability_content, content_digests, node_indexes): # Load article as DOM - soup = BeautifulSoup(readability_content, 'html.parser') + soup = BeautifulSoup(readability_content, "html.parser") # Make all elements plain elements = plain_elements(soup.contents, content_digests, node_indexes) if node_indexes: @@ -217,8 +220,7 @@ def plain_content(readability_content, content_digests, node_indexes): def plain_elements(elements, content_digests, node_indexes): # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) - for element in elements] + elements = [plain_element(element, content_digests, node_indexes) for element in elements] if content_digests: # Add content digest attribute to nodes elements = [add_content_digest(element) for element in elements] @@ -258,11 +260,9 @@ def add_node_indexes(element, node_index="0"): # Add index to current element element["data-node-index"] = node_index # Add index to child elements - for local_idx, child in enumerate( - [c for c in element.contents if not is_text(c)], start=1): + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format( - stem=node_index, local=local_idx) + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) add_node_indexes(child, node_index=child_index) return element @@ -284,11 +284,16 @@ def strip_control_characters(text): # [Cn]: Other, Not Assigned # [Co]: Other, Private Use # [Cs]: Other, Surrogate - control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} - retained_chars = ['\t', '\n', '\r', '\f'] + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] # Remove non-printing control characters - return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) def normalize_unicode(text): @@ -305,8 +310,9 @@ def normalize_whitespace(text): text = text.strip() return text + def is_leaf(element): - return (element.name in ['p', 'li']) + return element.name in ["p", "li"] def is_text(element): @@ -330,7 +336,7 @@ def content_digest(element): if trimmed_string == "": digest = "" else: - digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() else: contents = element.contents num_contents = len(contents) @@ -343,9 +349,8 @@ def content_digest(element): else: # Build content digest from the "non-empty" digests of child nodes digest = hashlib.sha256() - child_digests = list( - filter(lambda x: x != "", [content_digest(content) for content in contents])) + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) for child in child_digests: - digest.update(child.encode('utf-8')) + digest.update(child.encode("utf-8")) digest = digest.hexdigest() return digest diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index ff5505bbbfb9f5..94d9fd9eb90e83 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -10,27 +10,25 @@ def check_parameter_configurations(cls, configurations: list[dict]): """ for configuration in configurations: if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError('invalid parameter configuration') + raise ValueError("invalid parameter configuration") @classmethod def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: """ get workflow graph variables """ - nodes = graph.get('nodes', []) - start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) if not start_node: return [] - return [ - VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) - ] - + return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])] + @classmethod - def check_is_synced(cls, - variables: list[VariableEntity], - tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ) -> None: """ check is synced @@ -39,10 +37,10 @@ def check_is_synced(cls, variable_names = [variable.variable for variable in variables] if len(tool_configurations) != len(variables): - raise ValueError('parameter configuration mismatch, please republish the tool to update') - + raise ValueError("parameter configuration mismatch, please republish the tool to update") + for parameter in tool_configurations: if parameter.name not in variable_names: - raise ValueError('parameter configuration mismatch, please republish the tool to update') + raise ValueError("parameter configuration mismatch, please republish the tool to update") - return True \ No newline at end of file + return True diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index f751c430963cda..bcb061376dc5a2 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :return: an object of the YAML content """ try: - with open(file_path, encoding='utf-8') as yaml_file: + with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) return yaml_content if yaml_content else default_value except Exception as e: - raise YAMLError(f'Failed to load YAML file {file_path}: {e}') + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") except Exception as e: if ignore_error: return default_value diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 9015eea85ce9db..83086d1afc9018 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -5,10 +5,7 @@ class WorkflowCallback(ABC): @abstractmethod - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: """ Published event """ diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index e7e6710cbdf8a0..2a864dd7a84c8b 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -8,9 +8,11 @@ class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + class BaseIterationNodeData(BaseNodeData): start_node_id: Optional[str] = None + class BaseIterationState(BaseModel): iteration_node_id: str index: int @@ -19,4 +21,4 @@ class BaseIterationState(BaseModel): class MetaData(BaseModel): pass - metadata: MetaData \ No newline at end of file + metadata: MetaData diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5e2a5cb4663e1b..5353b99ed38b14 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,28 +12,28 @@ class NodeType(Enum): Node Types. """ - START = 'start' - END = 'end' - ANSWER = 'answer' - LLM = 'llm' - KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' - IF_ELSE = 'if-else' - CODE = 'code' - TEMPLATE_TRANSFORM = 'template-transform' - QUESTION_CLASSIFIER = 'question-classifier' - HTTP_REQUEST = 'http-request' - TOOL = 'tool' - VARIABLE_AGGREGATOR = 'variable-aggregator' + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = 'variable-assigner' - LOOP = 'loop' - ITERATION = 'iteration' - ITERATION_START = 'iteration-start' # fake start node for iteration - PARAMETER_EXTRACTOR = 'parameter-extractor' - CONVERSATION_VARIABLE_ASSIGNER = 'assigner' + VARIABLE_ASSIGNER = "variable-assigner" + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # fake start node for iteration + PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" @classmethod - def value_of(cls, value: str) -> 'NodeType': + def value_of(cls, value: str) -> "NodeType": """ Get value of given node type. @@ -43,7 +43,7 @@ def value_of(cls, value: str) -> 'NodeType': for node_type in cls: if node_type.value == value: return node_type - raise ValueError(f'invalid node type value {value}') + raise ValueError(f"invalid node type value {value}") class NodeRunMetadataKey(Enum): @@ -51,16 +51,16 @@ class NodeRunMetadataKey(Enum): Node Run Metadata Key. """ - TOTAL_TOKENS = 'total_tokens' - TOTAL_PRICE = 'total_price' - CURRENCY = 'currency' - TOOL_INFO = 'tool_info' - ITERATION_ID = 'iteration_id' - ITERATION_INDEX = 'iteration_index' - PARALLEL_ID = 'parallel_id' - PARALLEL_START_NODE_ID = 'parallel_start_node_id' - PARENT_PARALLEL_ID = 'parent_parallel_id' - PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" class NodeRunResult(BaseModel): @@ -85,6 +85,7 @@ class UserFrom(Enum): """ User from """ + ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a6171a4..1dfb1852f8ccc3 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -5,5 +5,6 @@ class VariableSelector(BaseModel): """ Variable Selector. """ + variable: str value_selector: list[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 48a20d25ae422c..b94b7f7198d988 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -23,23 +23,19 @@ class VariablePool(BaseModel): # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: dict[str, dict[int, Segment]] = Field( - description='Variables mapping', - default=defaultdict(dict) + description="Variables mapping", default=defaultdict(dict) ) # TODO: This user inputs is not used for pool. user_inputs: Mapping[str, Any] = Field( - description='User inputs', + description="User inputs", ) system_variables: Mapping[SystemVariableKey, Any] = Field( - description='System variables', + description="System variables", ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list - ) + environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) conversation_variables: Sequence[Variable] | None = None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 4bf4e454bbfde3..0a1eb57de4d003 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -46,13 +46,16 @@ class NodeRun(BaseModel): current_iteration_state: Optional[BaseIterationState] - def __init__(self, workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int): + def __init__( + self, + workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int, + ): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index 4099def4e2ab57..697392b2a3c23f 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -8,19 +8,13 @@ class RunConditionHandler(ABC): - def __init__(self, - init_params: GraphInitParams, - graph: Graph, - condition: RunCondition): + def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition): self.init_params = init_params self.graph = graph self.condition = condition @abstractmethod - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index 705eb908b1607a..af695df7d84607 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -4,10 +4,7 @@ class BranchIdentifyRunConditionHandler(RunConditionHandler): - - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index 1edaf92da7afab..eda5fe079ca179 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,10 +5,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed @@ -22,8 +19,7 @@ def check(self, # process condition condition_processor = ConditionProcessor() input_conditions, group_result = condition_processor.process_conditions( - variable_pool=graph_runtime_state.variable_pool, - conditions=self.condition.conditions + variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions ) # Apply the logical operator for the current case diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py index 2eb2e58bfcabdf..1c9237d82fbe68 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -9,9 +9,7 @@ class ConditionManager: @staticmethod def get_condition_handler( - init_params: GraphInitParams, - graph: Graph, - run_condition: RunCondition + init_params: GraphInitParams, graph: Graph, run_condition: RunCondition ) -> RunConditionHandler: """ Get condition handler @@ -22,14 +20,6 @@ def get_condition_handler( :return: condition handler """ if run_condition.type == "branch_identify": - return BranchIdentifyRunConditionHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition) else: - return ConditionRunConditionHandlerHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 49007b870d7f8f..b54595f7802d9b 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -34,38 +34,25 @@ class Graph(BaseModel): root_node_id: str = Field(..., description="root node id of the graph") node_ids: list[str] = Field(default_factory=list, description="graph node ids") node_id_config_mapping: dict[str, dict] = Field( - default_factory=list, - description="node configs mapping (node id: node config)" + default_factory=list, description="node configs mapping (node id: node config)" ) edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="graph edge mapping (source node id: edges)" + default_factory=dict, description="graph edge mapping (source node id: edges)" ) reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="reverse graph edge mapping (target node id: edges)" + default_factory=dict, description="reverse graph edge mapping (target node id: edges)" ) parallel_mapping: dict[str, GraphParallel] = Field( - default_factory=dict, - description="graph parallel mapping (parallel id: parallel)" + default_factory=dict, description="graph parallel mapping (parallel id: parallel)" ) node_parallel_mapping: dict[str, str] = Field( - default_factory=dict, - description="graph node parallel mapping (node id: parallel id)" - ) - answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( - ..., - description="answer stream generate routes" - ) - end_stream_param: EndStreamParam = Field( - ..., - description="end stream param" + default_factory=dict, description="graph node parallel mapping (node id: parallel id)" ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes") + end_stream_param: EndStreamParam = Field(..., description="end stream param") @classmethod - def init(cls, - graph_config: Mapping[str, Any], - root_node_id: Optional[str] = None) -> "Graph": + def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": """ Init graph @@ -74,7 +61,7 @@ def init(cls, :return: graph """ # edge configs - edge_configs = graph_config.get('edges') + edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] @@ -85,14 +72,14 @@ def init(cls, reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() for edge_config in edge_configs: - source_node_id = edge_config.get('source') + source_node_id = edge_config.get("source") if not source_node_id: continue if source_node_id not in edge_mapping: edge_mapping[source_node_id] = [] - target_node_id = edge_config.get('target') + target_node_id = edge_config.get("target") if not target_node_id: continue @@ -107,23 +94,18 @@ def init(cls, # parse run condition run_condition = None - if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': - run_condition = RunCondition( - type='branch_identify', - branch_identify=edge_config.get('sourceHandle') - ) + if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) # node configs - node_configs = graph_config.get('nodes') + node_configs = graph_config.get("nodes") if not node_configs: raise ValueError("Graph must have at least one node") @@ -133,7 +115,7 @@ def init(cls, root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} for node_config in node_configs: - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: continue @@ -142,30 +124,29 @@ def init(cls, all_node_id_config_mapping[node_id] = node_config - root_node_ids = [node_config.get('id') for node_config in root_node_configs] + root_node_ids = [node_config.get("id") for node_config in root_node_configs] # fetch root node if not root_node_id: # if no root node id, use the START type node as root node - root_node_id = next((node_config.get("id") for node_config in root_node_configs - if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + root_node_id = next( + ( + node_config.get("id") + for node_config in root_node_configs + if node_config.get("data", {}).get("type", "") == NodeType.START.value + ), + None, + ) if not root_node_id or root_node_id not in root_node_ids: raise ValueError(f"Root node id {root_node_id} not found in the graph") - + # Check whether it is connected to the previous node - cls._check_connected_to_previous_node( - route=[root_node_id], - edge_mapping=edge_mapping - ) + cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping) # fetch all node ids from root node node_ids = [root_node_id] - cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=root_node_id - ) + cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id) node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} @@ -177,29 +158,26 @@ def init(cls, reverse_edge_mapping=reverse_edge_mapping, start_node_id=root_node_id, parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # Check if it exceeds N layers of parallel for parallel in parallel_mapping.values(): if parallel.parent_parallel_id: cls._check_exceed_parallel_limit( - parallel_mapping=parallel_mapping, - level_limit=3, - parent_parallel_id=parallel.parent_parallel_id + parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id ) # init answer stream generate routes answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping + node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping ) # init end stream param end_stream_param = EndStreamGeneratorRouter.init( node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # init graph @@ -212,14 +190,14 @@ def init(cls, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping, answer_stream_generate_routes=answer_stream_generate_routes, - end_stream_param=end_stream_param + end_stream_param=end_stream_param, ) return graph - def add_extra_edge(self, source_node_id: str, - target_node_id: str, - run_condition: Optional[RunCondition] = None) -> None: + def add_extra_edge( + self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None + ) -> None: """ Add extra edge to the graph @@ -237,9 +215,7 @@ def add_extra_edge(self, source_node_id: str, return graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) self.edge_mapping[source_node_id].append(graph_edge) @@ -254,17 +230,18 @@ def get_leaf_node_ids(self) -> list[str]: for node_id in self.node_ids: if node_id not in self.edge_mapping: leaf_node_ids.append(node_id) - elif (len(self.edge_mapping[node_id]) == 1 - and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + elif ( + len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id + ): leaf_node_ids.append(node_id) return leaf_node_ids @classmethod - def _recursively_add_node_ids(cls, - node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - node_id: str) -> None: + def _recursively_add_node_ids( + cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str + ) -> None: """ Recursively add node ids @@ -278,17 +255,11 @@ def _recursively_add_node_ids(cls, node_ids.append(graph_edge.target_node_id) cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=graph_edge.target_node_id + node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id ) @classmethod - def _check_connected_to_previous_node( - cls, - route: list[str], - edge_mapping: dict[str, list[GraphEdge]] - ) -> None: + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: """ Check whether it is connected to the previous node """ @@ -299,7 +270,9 @@ def _check_connected_to_previous_node( continue if graph_edge.target_node_id in route: - raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + raise ValueError( + f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." + ) new_route = route[:] new_route.append(graph_edge.target_node_id) @@ -316,7 +289,7 @@ def _recursively_add_parallels( start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], - parent_parallel: Optional[GraphParallel] = None + parent_parallel: Optional[GraphParallel] = None, ) -> None: """ Recursively add parallel ids @@ -355,14 +328,14 @@ def _recursively_add_parallels( parallel = GraphParallel( start_from_node_id=start_node_id, parent_parallel_id=parent_parallel.id if parent_parallel else None, - parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, ) parallel_mapping[parallel.id] = parallel in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( edge_mapping=edge_mapping, reverse_edge_mapping=reverse_edge_mapping, - parallel_branch_node_ids=parallel_branch_node_ids + parallel_branch_node_ids=parallel_branch_node_ids, ) # collect all branches node ids @@ -403,14 +376,25 @@ def _recursively_add_parallels( continue if ( - (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) - or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) ): outside_parallel_target_node_ids.add(target_node_id) if len(outside_parallel_target_node_ids) == 1: - if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): parallel.end_to_node_id = None else: parallel.end_to_node_id = outside_parallel_target_node_ids.pop() @@ -420,18 +404,20 @@ def _recursively_add_parallels( if parallel: current_parallel = parallel elif parent_parallel: - if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): current_parallel = parent_parallel else: # fetch parent parallel's parent parallel parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id if parent_parallel_parent_parallel_id: parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) - if ( - parent_parallel_parent_parallel - and ( - not parent_parallel_parent_parallel.end_to_node_id - or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id + or ( + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id ) ): current_parallel = parent_parallel_parent_parallel @@ -442,7 +428,7 @@ def _recursively_add_parallels( start_node_id=graph_edge.target_node_id, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel + parent_parallel=current_parallel, ) @classmethod @@ -451,7 +437,7 @@ def _check_exceed_parallel_limit( parallel_mapping: dict[str, GraphParallel], level_limit: int, parent_parallel_id: str, - current_level: int = 1 + current_level: int = 1, ) -> None: """ Check if it exceeds N layers of parallel @@ -459,25 +445,27 @@ def _check_exceed_parallel_limit( parent_parallel = parallel_mapping.get(parent_parallel_id) if not parent_parallel: return - + current_level += 1 if current_level > level_limit: raise ValueError(f"Exceeds {level_limit} layers of parallel") - + if parent_parallel.parent_parallel_id: cls._check_exceed_parallel_limit( parallel_mapping=parallel_mapping, level_limit=level_limit, parent_parallel_id=parent_parallel.parent_parallel_id, - current_level=current_level + current_level=current_level, ) @classmethod - def _recursively_add_parallel_node_ids(cls, - branch_node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - merge_node_id: str, - start_node_id: str) -> None: + def _recursively_add_parallel_node_ids( + cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str, + ) -> None: """ Recursively add node ids @@ -487,21 +475,22 @@ def _recursively_add_parallel_node_ids(cls, :param start_node_id: start node id """ for graph_edge in edge_mapping.get(start_node_id, []): - if (graph_edge.target_node_id != merge_node_id - and graph_edge.target_node_id not in branch_node_ids): + if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids: branch_node_ids.append(graph_edge.target_node_id) cls._recursively_add_parallel_node_ids( branch_node_ids=branch_node_ids, edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=graph_edge.target_node_id + start_node_id=graph_edge.target_node_id, ) @classmethod - def _fetch_all_node_ids_in_parallels(cls, - edge_mapping: dict[str, list[GraphEdge]], - reverse_edge_mapping: dict[str, list[GraphEdge]], - parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: + def _fetch_all_node_ids_in_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str], + ) -> dict[str, list[str]]: """ Fetch all node ids in parallels """ @@ -513,7 +502,7 @@ def _fetch_all_node_ids_in_parallels(cls, cls._recursively_fetch_routes( edge_mapping=edge_mapping, start_node_id=parallel_branch_node_id, - routes_node_ids=routes_node_ids[parallel_branch_node_id] + routes_node_ids=routes_node_ids[parallel_branch_node_id], ) # fetch leaf node ids from routes node ids @@ -529,13 +518,13 @@ def _fetch_all_node_ids_in_parallels(cls, for branch_node_id2, inner_route2 in routes_node_ids.items(): if ( - branch_node_id != branch_node_id2 + branch_node_id != branch_node_id2 and node_id in inner_route2 and len(reverse_edge_mapping.get(node_id, [])) > 1 and cls._is_node_in_routes( reverse_edge_mapping=reverse_edge_mapping, start_node_id=node_id, - routes_node_ids=routes_node_ids + routes_node_ids=routes_node_ids, ) ): if node_id not in merge_branch_node_ids: @@ -551,23 +540,18 @@ def _fetch_all_node_ids_in_parallels(cls, for node_id, branch_node_ids in merge_branch_node_ids.items(): for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): - if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + if (node_id, node_id2) not in duplicate_end_node_ids and ( + node_id2, + node_id, + ) not in duplicate_end_node_ids: duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids - + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after - if cls._is_node2_after_node1( - node1_id=node_id, - node2_id=node_id2, - edge_mapping=edge_mapping - ): + if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): if node_id in merge_branch_node_ids: del merge_branch_node_ids[node_id2] - elif cls._is_node2_after_node1( - node1_id=node_id2, - node2_id=node_id, - edge_mapping=edge_mapping - ): + elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): if node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] @@ -599,16 +583,15 @@ def _fetch_all_node_ids_in_parallels(cls, branch_node_ids=in_branch_node_ids[branch_node_id], edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=branch_node_id + start_node_id=branch_node_id, ) return in_branch_node_ids @classmethod - def _recursively_fetch_routes(cls, - edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: list[str]) -> None: + def _recursively_fetch_routes( + cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] + ) -> None: """ Recursively fetch route """ @@ -621,28 +604,25 @@ def _recursively_fetch_routes(cls, routes_node_ids.append(graph_edge.target_node_id) cls._recursively_fetch_routes( - edge_mapping=edge_mapping, - start_node_id=graph_edge.target_node_id, - routes_node_ids=routes_node_ids + edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids ) @classmethod - def _is_node_in_routes(cls, - reverse_edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: dict[str, list[str]]) -> bool: + def _is_node_in_routes( + cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]] + ) -> bool: """ Recursively check if the node is in the routes """ if start_node_id not in reverse_edge_mapping: return False - + all_routes_node_ids = set() parallel_start_node_ids: dict[str, list[str]] = {} for branch_node_id, node_ids in routes_node_ids.items(): for node_id in node_ids: all_routes_node_ids.add(node_id) - + if branch_node_id in reverse_edge_mapping: for graph_edge in reverse_edge_mapping[branch_node_id]: if graph_edge.source_node_id not in parallel_start_node_ids: @@ -655,38 +635,34 @@ def _is_node_in_routes(cls, if set(branch_node_ids) == set(routes_node_ids.keys()): parallel_start_node_id = p_start_node_id return True - + if not parallel_start_node_id: raise Exception("Parallel start node id not found") - + for graph_edge in reverse_edge_mapping[start_node_id]: - if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: + if ( + graph_edge.source_node_id not in all_routes_node_ids + or graph_edge.source_node_id != parallel_start_node_id + ): return False - + return True @classmethod - def _is_node2_after_node1( - cls, - node1_id: str, - node2_id: str, - edge_mapping: dict[str, list[GraphEdge]] - ) -> bool: + def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: """ is node2 after node1 """ if node1_id not in edge_mapping: return False - + for graph_edge in edge_mapping[node1_id]: if graph_edge.target_node_id == node2_id: return True - + if cls._is_node2_after_node1( - node1_id=graph_edge.target_node_id, - node2_id=node2_id, - edge_mapping=edge_mapping + node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping ): return True - - return False \ No newline at end of file + + return False diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index c7d484ddf55f7c..afc09bfac5b0c1 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -10,7 +10,7 @@ class GraphRuntimeState(BaseModel): variable_pool: VariablePool = Field(..., description="variable pool") """variable pool""" - + start_at: float = Field(..., description="start time") """start time""" total_tokens: int = 0 diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py index 0362343568dd56..eedce8842b411e 100644 --- a/api/core/workflow/graph_engine/entities/run_condition.py +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -18,4 +18,4 @@ class RunCondition(BaseModel): @property def hash(self) -> str: - return hashlib.sha256(self.model_dump_json().encode()).hexdigest() \ No newline at end of file + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index b5d6e4c09d70e2..8fc80474265d9f 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -68,13 +68,11 @@ def set_finished(self, run_result: NodeRunResult) -> None: class RuntimeRouteState(BaseModel): routes: dict[str, list[str]] = Field( - default_factory=dict, - description="graph state routes (source_node_state_id: target_node_state_id)" + default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" ) node_state_mapping: dict[str, RouteNodeState] = Field( - default_factory=dict, - description="node state mapping (route_node_state_id: route_node_state)" + default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" ) def create_node_state(self, node_id: str) -> RouteNodeState: @@ -99,13 +97,13 @@ def add_route(self, source_node_state_id: str, target_node_state_id: str) -> Non self.routes[source_node_state_id].append(target_node_state_id) - def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ - -> list[RouteNodeState]: + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]: """ Get routes with node state by source node id :param source_node_state_id: source node state id :return: routes with node state """ - return [self.node_state_mapping[target_state_id] - for target_state_id in self.routes.get(source_node_state_id, [])] + return [ + self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) + ] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 65d9ab8446ef1c..c6bd122b3704db 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -48,8 +48,9 @@ class GraphEngineThreadPool(ThreadPoolExecutor): - def __init__(self, max_workers=None, thread_name_prefix='', - initializer=None, initargs=(), max_submit_count=100) -> None: + def __init__( + self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100 + ) -> None: super().__init__(max_workers, thread_name_prefix, initializer, initargs) self.max_submit_count = max_submit_count self.submit_count = 0 @@ -57,9 +58,9 @@ def __init__(self, max_workers=None, thread_name_prefix='', def submit(self, fn, *args, **kwargs): self.submit_count += 1 self.check_is_full() - + return super().submit(fn, *args, **kwargs) - + def check_is_full(self) -> None: print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") if self.submit_count > self.max_submit_count: @@ -70,21 +71,21 @@ class GraphEngine: workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} def __init__( - self, - tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - graph: Graph, - graph_config: Mapping[str, Any], - variable_pool: VariablePool, - max_execution_steps: int, - max_execution_time: int, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None, ) -> None: thread_pool_max_submit_count = 100 thread_pool_max_workers = 10 @@ -93,12 +94,14 @@ def __init__( if thread_pool_id: if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") - + self.thread_pool_id = thread_pool_id self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] self.is_main_thread_pool = False else: - self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool = GraphEngineThreadPool( + max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count + ) self.thread_pool_id = str(uuid.uuid4()) self.is_main_thread_pool = True GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool @@ -113,13 +116,10 @@ def __init__( user_id=user_id, user_from=user_from, invoke_from=invoke_from, - call_depth=call_depth + call_depth=call_depth, ) - self.graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.max_execution_steps = max_execution_steps self.max_execution_time = max_execution_time @@ -136,37 +136,40 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: stream_processor_cls = EndStreamProcessor stream_processor = stream_processor_cls( - graph=self.graph, - variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool ) # run graph - generator = stream_processor.process( - self._run(start_node_id=self.graph.root_node_id) - ) + generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) for item in generator: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: - self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else {}) + self.graph_runtime_state.outputs = ( + item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {} + ) elif item.node_type == NodeType.ANSWER: if "answer" not in self.graph_runtime_state.outputs: self.graph_runtime_state.outputs["answer"] = "" - self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else "") - - self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() + self.graph_runtime_state.outputs["answer"] += "\n" + ( + item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "" + ) + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[ + "answer" + ].strip() except Exception as e: logger.exception(f"Graph run failed: {str(e)}") yield GraphRunFailedEvent(error=str(e)) @@ -186,12 +189,12 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] def _run( - self, - start_node_id: str, - in_parallel_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None - ) -> Generator[GraphEngineEvent, None, None]: + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: parallel_start_node_id = start_node_id @@ -201,31 +204,28 @@ def _run( while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) # or max execution time reached if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, - max_execution_time=self.max_execution_time + start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time ): - raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) # init route node state - route_node_state = self.graph_runtime_state.node_run_state.create_node_state( - node_id=next_node_id - ) + route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) # get node config node_id = route_node_state.node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} config not found.') + raise GraphRunFailedError(f"Node {node_id} config not found.") # convert to specific node - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: - raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None @@ -237,7 +237,7 @@ def _run( graph=self.graph, graph_runtime_state=self.graph_runtime_state, previous_node_id=previous_node_id, - thread_pool_id=self.thread_pool_id + thread_pool_id=self.thread_pool_id, ) try: @@ -248,7 +248,7 @@ def _run( parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: @@ -263,8 +263,7 @@ def _run( # append route if previous_route_node_state: self.graph_runtime_state.node_run_state.add_route( - source_node_state_id=previous_route_node_state.id, - target_node_state_id=route_node_state.id + source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id ) except Exception as e: route_node_state.status = RouteNodeState.Status.FAILED @@ -279,13 +278,15 @@ def _run( parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) raise e # It may not be necessary, but it is necessary. :) - if (self.graph.node_id_config_mapping[next_node_id] - .get("data", {}).get("type", "").lower() == NodeType.END.value): + if ( + self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() + == NodeType.END.value + ): break previous_route_node_state = route_node_state @@ -305,7 +306,7 @@ def _run( run_condition=edge.run_condition, ).check( graph_runtime_state=self.graph_runtime_state, - previous_route_node_state=previous_route_node_state + previous_route_node_state=previous_route_node_state, ) if not result: @@ -343,14 +344,14 @@ def _run( if not result: continue - + if len(sub_edge_mappings) == 1: final_node_id = edge.target_node_id else: parallel_generator = self._run_parallel_branches( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -369,7 +370,7 @@ def _run( parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -383,14 +384,14 @@ def _run( next_node_id = final_node_id - if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id: break def _run_parallel_branches( - self, - edge_mappings: list[GraphEdge], - in_parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -398,14 +399,18 @@ def _run_parallel_branches( node_id = edge_mappings[0].target_node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') + raise GraphRunFailedError( + f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches." + ) - node_title = node_config.get('data', {}).get('title') - raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') + node_title = node_config.get("data", {}).get("title") + raise GraphRunFailedError( + f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches." + ) parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: - raise GraphRunFailedError(f'Parallel {parallel_id} not found.') + raise GraphRunFailedError(f"Parallel {parallel_id} not found.") # run parallel nodes, run in new thread and use queue to get results q: queue.Queue = queue.Queue() @@ -417,19 +422,22 @@ def _run_parallel_branches( for edge in edge_mappings: if ( edge.target_node_id not in self.graph.node_parallel_mapping - or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id ): continue futures.append( - self.thread_pool.submit(self._run_parallel_node, **{ - 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] - 'q': q, - 'parallel_id': parallel_id, - 'parallel_start_node_id': edge.target_node_id, - 'parent_parallel_id': in_parallel_id, - 'parent_parallel_start_node_id': parallel_start_node_id, - }) + self.thread_pool.submit( + self._run_parallel_node, + **{ + "flask_app": current_app._get_current_object(), # type: ignore[attr-defined] + "q": q, + "parallel_id": parallel_id, + "parallel_start_node_id": edge.target_node_id, + "parent_parallel_id": in_parallel_id, + "parent_parallel_start_node_id": parallel_start_node_id, + }, + ) ) succeeded_count = 0 @@ -451,7 +459,7 @@ def _run_parallel_branches( raise GraphRunFailedError(event.error) except queue.Empty: continue - + # wait all threads wait(futures) @@ -461,72 +469,80 @@ def _run_parallel_branches( yield final_node_id def _run_parallel_node( - self, - flask_app: Flask, - q: queue.Queue, - parallel_id: str, - parallel_start_node_id: str, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> None: """ Run parallel nodes """ with flask_app.app_context(): try: - q.put(ParallelBranchRunStartedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) # run node generator = self._run( start_node_id=parallel_start_node_id, in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: q.put(item) # trigger graph run success event - q.put(ParallelBranchRunSucceededEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) except GraphRunFailedError as e: - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=e.error - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error, + ) + ) except Exception as e: logger.exception("Unknown Error when generating in parallel") - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=str(e) - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e), + ) + ) finally: db.session.remove() def _run_node( - self, - node_instance: BaseNode, - route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -542,7 +558,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) db.session.close() @@ -567,7 +583,7 @@ def _run_node( if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( - error=route_node_state.failed_reason or 'Unknown error.', + error=route_node_state.failed_reason or "Unknown error.", id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, @@ -576,7 +592,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): @@ -596,7 +612,7 @@ def _run_node( self._append_variables_recursively( node_id=node_instance.node_id, variable_key_list=[variable_key], - variable_value=variable_value + variable_value=variable_value, ) # add parallel info to run result metadata @@ -608,7 +624,9 @@ def _run_node( run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) yield NodeRunSucceededEvent( id=node_instance.id, @@ -619,7 +637,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) break @@ -635,7 +653,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -649,7 +667,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) except GenerateTaskStoppedException: # trigger node run failed event @@ -665,7 +683,7 @@ def _run_node( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) return except Exception as e: @@ -674,10 +692,7 @@ def _run_node( finally: db.session.close() - def _append_variables_recursively(self, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): + def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ Append variables recursively :param node_id: node id @@ -685,10 +700,7 @@ def _append_variables_recursively(self, :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add( - [node_id] + variable_key_list, - variable_value - ) + self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) # if variable_value is a dict, then recursively append variables if isinstance(variable_value, dict): @@ -696,9 +708,7 @@ def _append_variables_recursively(self, # construct new key list new_key_list = variable_key_list + [key] self._append_variables_recursively( - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value + node_id=node_id, variable_key_list=new_key_list, variable_value=value ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 8cf01727ec9296..deacbbbbb0cbee 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -29,14 +29,12 @@ def _run(self) -> NodeRunResult: # generate routes generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) - answer = '' + answer = "" for part in generate_routes: if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = self.graph_runtime_state.variable_pool.get( - value_selector - ) + value = self.graph_runtime_state.variable_pool.get(value_selector) if value: answer += value.markdown @@ -44,19 +42,11 @@ def _run(self) -> NodeRunResult: part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "answer": answer - } - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -73,6 +63,6 @@ def _extract_variable_selector_to_variable_mapping( variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 6cb80091c91fbd..2562b9ce960ec3 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -1,4 +1,3 @@ - from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import ( @@ -12,12 +11,12 @@ class AnswerStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] - ) -> AnswerStreamGenerateRoute: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: """ Get stream generate routes. :return: @@ -25,7 +24,7 @@ def init(cls, # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} for answer_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value: continue # get generate route for stream output @@ -37,12 +36,11 @@ def init(cls, answer_dependencies = cls._fetch_answers_dependencies( answer_node_ids=answer_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, - answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies ) @classmethod @@ -56,8 +54,7 @@ def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> lis variable_selectors = variable_template_parser.extract_variable_selectors() value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors + variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors } variable_keys = list(value_selector_mapping.keys()) @@ -71,21 +68,17 @@ def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> lis template = node_data.answer for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω") generate_routes: list[GenerateRouteChunk] = [] - for part in template.split('Ω'): + for part in template.split("Ω"): if part: if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "") value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) + generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector)) else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) + generate_routes.append(TextGenerateRouteChunk(text=part)) return generate_routes @@ -101,15 +94,16 @@ def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteCh @classmethod def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys + cleaned_part = part.replace("{{", "").replace("}}", "") + return part.startswith("{{") and cleaned_part in variable_keys @classmethod - def _fetch_answers_dependencies(cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_answers_dependencies( + cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch answer dependencies :param answer_node_ids: answer node ids @@ -127,19 +121,20 @@ def _fetch_answers_dependencies(cls, answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) return answer_dependencies @classmethod - def _recursive_fetch_answer_dependencies(cls, - current_node_id: str, - answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_answer_dependencies( + cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch answer dependencies :param current_node_id: current node id @@ -152,11 +147,11 @@ def _recursive_fetch_answer_dependencies(cls, reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in ( - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, ): answer_dependencies[answer_node_id].append(source_node_id) else: @@ -165,5 +160,5 @@ def _recursive_fetch_answer_dependencies(cls, answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index c2a5dd5163819d..9776ce5810dbaa 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -18,7 +18,6 @@ class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes @@ -27,9 +26,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.route_position[answer_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -47,9 +44,9 @@ def process(self, ] else: stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_answer_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_answer_node_ids + ) for _ in stream_out_answer_node_ids: yield event @@ -77,9 +74,9 @@ def reset(self) -> None: self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -87,10 +84,13 @@ def _generate_stream_outputs_when_node_finished(self, """ for answer_node_id, position in self.route_position.items(): # all depends on answer node id not in rest node ids - if (event.route_node_state.node_id != answer_node_id - and (answer_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): + if event.route_node_state.node_id != answer_node_id and ( + answer_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ) + ): continue route_position = self.route_position[answer_node_id] @@ -115,9 +115,7 @@ def _generate_stream_outputs_when_node_finished(self, if not value_selector: break - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -158,8 +156,9 @@ def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> lis continue # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): + if all( + dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ): if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): continue @@ -213,7 +212,7 @@ def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index cbabbca37d0558..36c3fe180a9cb2 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -7,16 +7,13 @@ class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.graph = graph self.variable_pool = variable_pool self.rest_node_ids = graph.node_ids.copy() @abstractmethod - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: @@ -35,9 +32,11 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: reachable_node_ids = [] unreachable_first_node_ids = [] for edge in self.graph.edge_mapping[finished_node_id]: - if (edge.run_condition - and edge.run_condition.branch_identify - and run_result.edge_source_handle == edge.run_condition.branch_identify): + if ( + edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify + ): reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) continue else: diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 620c2c426bae9e..e356e7fd702245 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ + answer: str = Field(..., description="answer template string") @@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR """generate route chunk type""" value_selector: list[str] = Field(..., description="value selector") @@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT """generate route chunk type""" text: str = Field(..., description="text") @@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel): """ AnswerStreamGenerateRoute entity """ + answer_dependencies: dict[str, list[str]] = Field( - ..., - description="answer dependencies (answer node id -> dependent answer node ids)" + ..., description="answer dependencies (answer node id -> dependent answer node ids)" ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., - description="answer generate route (answer node id -> generate route chunks)" + ..., description="answer generate route (answer node id -> generate route chunks)" ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index b9912314f136f3..7bfe45a13c577d 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -15,14 +15,16 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - def __init__(self, - id: str, - config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None) -> None: + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + ) -> None: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -46,8 +48,7 @@ def __init__(self, self.node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod - def _run(self) \ - -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node :return: @@ -62,14 +63,14 @@ def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: result = self._run() if isinstance(result, NodeRunResult): - yield RunCompletedEvent( - run_result=result - ) + yield RunCompletedEvent(run_result=result) else: yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: + def extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], config: dict + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping :param graph_config: graph config @@ -82,17 +83,12 @@ def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=node_data ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: BaseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 955afdfa1dd4a1..4a1787c8c17ca1 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -25,11 +25,10 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) + code_language = filters.get("code_language", CodeLanguage.PYTHON3) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers - if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) return code_provider.get_default_config() @@ -63,17 +62,9 @@ def _run(self) -> NodeRunResult: # Transform result result = self._transform_result(result, node_data.outputs) except (CodeExecutionException, ValueError) as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=result - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) def _check_string(self, value: str, variable: str) -> str: """ @@ -87,12 +78,14 @@ def _check_string(self, value: str, variable: str) -> str: return None else: raise ValueError(f"Output variable `{variable}` must be a string") - + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be' - f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') + raise ValueError( + f"The length of output variable `{variable}` must be" + f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + ) - return value.replace('\x00', '') + return value.replace("\x00", "") def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ @@ -108,20 +101,24 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f raise ValueError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range,' - f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') + raise ValueError( + f"Output variable `{variable}` is out of range," + f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + ) if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision,' - f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') + if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError( + f"Output variable `{variable}` has too high precision," + f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + ) return value - def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], - prefix: str = '', - depth: int = 1) -> dict: + def _transform_result( + self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 + ) -> dict: """ Transform result :param result: result @@ -139,183 +136,181 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code self._transform_result( result=output_value, output_schema=None, - prefix=f'{prefix}.{output_name}' if prefix else output_name, - depth=depth + 1 + prefix=f"{prefix}.{output_name}" if prefix else output_name, + depth=depth + 1, ) elif isinstance(output_value, int | float): self._check_number( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, str): self._check_string( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all( + value is None or isinstance(value, int | float) for value in output_value + ): for i, value in enumerate(output_value): self._check_number( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all( + value is None or isinstance(value, str) for value in output_value + ): for i, value in enumerate(output_value): self._check_string( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all( + value is None or isinstance(value, dict) for value in output_value + ): for i, value in enumerate(output_value): if value is not None: self._transform_result( result=value, output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", + depth=depth + 1, ) else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + raise ValueError( + f"Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type." + ) elif isinstance(output_value, type(None)): pass else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') - + raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + return result parameters_validated = {} for output_name, output_config in output_schema.items(): - dot = '.' if prefix else '' + dot = "." if prefix else "" if output_name not in result: - raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') - - if output_config.type == 'object': + raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + + if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead." ) else: transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 + prefix=f"{prefix}.{output_name}", + depth=depth + 1, ) - elif output_config.type == 'number': + elif output_config.type == "number": # check if number available transformed_result[output_name] = self._check_number( - value=result[output_name], - variable=f'{prefix}{dot}{output_name}' + value=result[output_name], variable=f"{prefix}{dot}{output_name}" ) - elif output_config.type == 'string': + elif output_config.type == "string": # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], - variable=f'{prefix}{dot}{output_name}', + variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == 'array[number]': + elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[string]': + elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_string( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[object]': + elif output_config.type == "array[object]": # check if array of object available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) - + for i, value in enumerate(result[output_name]): if not isinstance(value, dict): if isinstance(value, type(None)): pass else: raise ValueError( - f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + f"Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}." ) transformed_result[output_name] = [ - None if value is None else self._transform_result( + None + if value is None + else self._transform_result( result=value, output_schema=output_config.children, - prefix=f'{prefix}{dot}{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}{dot}{output_name}[{i}]", + depth=depth + 1, ) for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f'Output type {output_config.type} is not supported.') - + raise ValueError(f"Output type {output_config.type} is not supported.") + parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError('Not all output parameters are validated.') + raise ValueError("Not all output parameters are validated.") return transformed_result @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -325,5 +320,6 @@ def _extract_variable_selector_to_variable_mapping( :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index c0701ecccd2132..5eb0e0f63f4f35 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData): """ Code Node Data. """ + class Output(BaseModel): - type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] - children: Optional[dict[str, 'Output']] = None + type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + children: Optional[dict[str, "Output"]] = None class Dependency(BaseModel): name: str @@ -23,4 +24,4 @@ class Dependency(BaseModel): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 552914b308a9b7..7b78d67be8780d 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -25,18 +25,11 @@ def _run(self) -> NodeRunResult: value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: EndNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 8390f6d81b4d31..30ce8fe018cef5 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -3,13 +3,13 @@ class EndStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_parallel_mapping: dict[str, str] - ) -> EndStreamParam: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str], + ) -> EndStreamParam: """ Get stream generate routes. :return: @@ -17,7 +17,7 @@ def init(cls, # parse stream output node value selector of end nodes end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} for end_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.END.value: + if not node_config.get("data", {}).get("type") == NodeType.END.value: continue # skip end node in parallel @@ -33,18 +33,18 @@ def init(cls, end_dependencies = cls._fetch_ends_dependencies( end_node_ids=end_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return EndStreamParam( end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) @classmethod - def extract_stream_variable_selector_from_node_data(cls, - node_id_config_mapping: dict[str, dict], - node_data: EndNodeData) -> list[list[str]]: + def extract_stream_variable_selector_from_node_data( + cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData + ) -> list[list[str]]: """ Extract stream variable selector from node data :param node_id_config_mapping: node id config mapping @@ -59,21 +59,22 @@ def extract_stream_variable_selector_from_node_data(cls, continue node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_id_config_mapping: + if node_id != "sys" and node_id in node_id_config_mapping: node = node_id_config_mapping[node_id] - node_type = node.get('data', {}).get('type') + node_type = node.get("data", {}).get("type") if ( variable_selector.value_selector not in value_selectors - and node_type == NodeType.LLM.value - and variable_selector.value_selector[1] == 'text' + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == "text" ): value_selectors.append(variable_selector.value_selector) return value_selectors @classmethod - def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ - -> list[list[str]]: + def _extract_stream_variable_selector( + cls, node_id_config_mapping: dict[str, dict], config: dict + ) -> list[list[str]]: """ Extract stream variable selector from node config :param node_id_config_mapping: node id config mapping @@ -84,11 +85,12 @@ def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dic return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) @classmethod - def _fetch_ends_dependencies(cls, - end_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_ends_dependencies( + cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch end dependencies :param end_node_ids: end node ids @@ -106,20 +108,21 @@ def _fetch_ends_dependencies(cls, end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) return end_dependencies @classmethod - def _recursive_fetch_end_dependencies(cls, - current_node_id: str, - end_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] - end_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_end_dependencies( + cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch end dependencies :param current_node_id: current node id @@ -132,10 +135,10 @@ def _recursive_fetch_end_dependencies(cls, reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in ( - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, ): end_dependencies[end_node_id].append(source_node_id) else: @@ -144,5 +147,5 @@ def _recursive_fetch_end_dependencies(cls, end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 4474c2a78a2740..0366d7965d7c17 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -15,7 +15,6 @@ class EndStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.end_stream_param = graph.end_stream_param @@ -26,9 +25,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.has_outputed = False self.outputed_node_ids = set() - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -38,7 +35,7 @@ def process(self, elif isinstance(event, NodeRunStreamChunkEvent): if event.in_iteration_id: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -51,13 +48,13 @@ def process(self, ] else: stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_end_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_end_node_ids + ) if stream_out_end_node_ids: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -86,9 +83,9 @@ def reset(self) -> None: self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -96,10 +93,12 @@ def _generate_stream_outputs_when_node_finished(self, """ for end_node_id, position in self.route_position.items(): # all depends on end node id not in rest node ids - if (event.route_node_state.node_id != end_node_id - and (end_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + if event.route_node_state.node_id != end_node_id and ( + end_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id] + ) + ): continue route_position = self.route_position[end_node_id] @@ -116,9 +115,7 @@ def _generate_stream_outputs_when_node_finished(self, if not value_selector: continue - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -128,7 +125,7 @@ def _generate_stream_outputs_when_node_finished(self, if text: current_node_id = value_selector[0] if self.has_outputed and current_node_id not in self.outputed_node_ids: - text = '\n' + text + text = "\n" + text self.outputed_node_ids.add(current_node_id) self.has_outputed = True @@ -165,8 +162,7 @@ def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[s continue # all depends on end node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]): if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): continue @@ -178,7 +174,7 @@ def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[s break position += 1 - + if not value_selector: continue diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index a0edf7b5796bba..c3270ac22a4149 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData): """ END Node Data. """ + outputs: list[VariableSelector] @@ -15,11 +16,10 @@ class EndStreamParam(BaseModel): """ EndStreamParam entity """ + end_dependencies: dict[str, list[str]] = Field( - ..., - description="end dependencies (end node id -> dependent node ids)" + ..., description="end dependencies (end node id -> dependent node ids)" ) end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., - description="end stream variable selector mapping (end node id -> stream variable selectors)" + ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" ) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index c066d469d8c553..66dd1f2dc60d12 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -7,32 +7,32 @@ class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, 'basic', 'bearer', 'custom'] + type: Literal[None, "basic", "bearer", "custom"] api_key: Union[None, str] = None header: Union[None, str] = None class HttpRequestNodeAuthorization(BaseModel): - type: Literal['no-auth', 'api-key'] + type: Literal["no-auth", "api-key"] config: Optional[HttpRequestNodeAuthorizationConfig] = None - @field_validator('config', mode='before') + @field_validator("config", mode="before") @classmethod def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): """ Check config, if type is no-auth, config should be None, otherwise it should be a dict. """ - if values.data['type'] == 'no-auth': + if values.data["type"] == "no-auth": return None else: if not v or not isinstance(v, dict): - raise ValueError('config should be a dict') + raise ValueError("config should be a dict") return v class HttpRequestNodeBody(BaseModel): - type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"] data: Union[None, str] = None @@ -47,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ - method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] + method: Literal["get", "post", "put", "patch", "delete", "head"] url: str authorization: HttpRequestNodeAuthorization headers: str diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index d16bff58bda12b..49102dc3ab127a 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -33,12 +33,12 @@ def is_file(self) -> bool: check if response is file """ content_type = self.get_content_type() - file_content_types = ['image', 'audio', 'video'] + file_content_types = ["image", "audio", "video"] return any(v in content_type for v in file_content_types) def get_content_type(self) -> str: - return self.headers.get('content-type', '') + return self.headers.get("content-type", "") def extract_file(self) -> tuple[str, bytes]: """ @@ -47,28 +47,28 @@ def extract_file(self) -> tuple[str, bytes]: if self.is_file: return self.get_content_type(), self.body - return '', b'' + return "", b"" @property def content(self) -> str: if isinstance(self.response, httpx.Response): return self.response.text else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def body(self) -> bytes: if isinstance(self.response, httpx.Response): return self.response.content else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def status_code(self) -> int: if isinstance(self.response, httpx.Response): return self.response.status_code else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def size(self) -> int: @@ -77,11 +77,11 @@ def size(self) -> int: @property def readable_size(self) -> str: if self.size < 1024: - return f'{self.size} bytes' + return f"{self.size} bytes" elif self.size < 1024 * 1024: - return f'{(self.size / 1024):.2f} KB' + return f"{(self.size / 1024):.2f} KB" else: - return f'{(self.size / 1024 / 1024):.2f} MB' + return f"{(self.size / 1024 / 1024):.2f} MB" class HttpExecutor: @@ -120,7 +120,7 @@ def _is_json_body(body: HttpRequestNodeBody): """ check if body is json """ - if body and body.type == 'json' and body.data: + if body and body.type == "json" and body.data: try: json.loads(body.data) return True @@ -134,15 +134,15 @@ def _to_dict(convert_text: str): """ Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` """ - kv_paris = convert_text.split('\n') + kv_paris = convert_text.split("\n") result = {} for kv in kv_paris: if not kv.strip(): continue - kv = kv.split(':', maxsplit=1) + kv = kv.split(":", maxsplit=1) if len(kv) == 1: - k, v = kv[0], '' + k, v = kv[0], "" else: k, v = kv result[k.strip()] = v @@ -166,31 +166,31 @@ def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional # check if it's a valid JSON is_valid_json = self._is_json_body(node_data.body) - body_data = node_data.body.data or '' + body_data = node_data.body.data or "" if body_data: body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - content_type_is_set = any(key.lower() == 'content-type' for key in self.headers) - if node_data.body.type == 'json' and not content_type_is_set: - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set: - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + content_type_is_set = any(key.lower() == "content-type" for key in self.headers) + if node_data.body.type == "json" and not content_type_is_set: + self.headers["Content-Type"] = "application/json" + elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: + self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: body = self._to_dict(body_data) - if node_data.body.type == 'form-data': - self.files = {k: ('', v) for k, v in body.items()} - random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f'----WebKitFormBoundary{random_str(16)}' + if node_data.body.type == "form-data": + self.files = {k: ("", v) for k, v in body.items()} + random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) + self.boundary = f"----WebKitFormBoundary{random_str(16)}" - self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}' + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ['json', 'raw-text']: + elif node_data.body.type in ["json", "raw-text"]: self.body = body_data - elif node_data.body.type == 'none': - self.body = '' + elif node_data.body.type == "none": + self.body = "" self.variable_selectors = ( server_url_variable_selectors @@ -202,23 +202,23 @@ def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} - if self.authorization.type == 'api-key': + if self.authorization.type == "api-key": if self.authorization.config is None: - raise ValueError('self.authorization config is required') + raise ValueError("self.authorization config is required") if authorization.config is None: - raise ValueError('authorization config is required') + raise ValueError("authorization config is required") if self.authorization.config.api_key is None: - raise ValueError('api_key is required') + raise ValueError("api_key is required") if not authorization.config.header: - authorization.config.header = 'Authorization' + authorization.config.header = "Authorization" - if self.authorization.config.type == 'bearer': - headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' - elif self.authorization.config.type == 'basic': - headers[authorization.config.header] = f'Basic {authorization.config.api_key}' - elif self.authorization.config.type == 'custom': + if self.authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.authorization.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key return headers @@ -230,10 +230,13 @@ def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutor if isinstance(response, httpx.Response): executor_response = HttpExecutorResponse(response) else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") - threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) if executor_response.size > threshold_size: raise ValueError( f'{"File" if executor_response.is_file else "Text"} size is too large,' @@ -248,17 +251,17 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: do http request depending on api bundle """ kwargs = { - 'url': self.server_url, - 'headers': headers, - 'params': self.params, - 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), - 'follow_redirects': True, + "url": self.server_url, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, } - if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): + if self.method in ("get", "head", "post", "put", "delete", "patch"): response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") return response def invoke(self) -> HttpExecutorResponse: @@ -280,15 +283,15 @@ def to_raw_request(self) -> str: """ server_url = self.server_url if self.params: - server_url += f'?{urlencode(self.params)}' + server_url += f"?{urlencode(self.params)}" - raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" headers = self._assembling_headers() for k, v in headers.items(): # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' + if self.authorization.type == "api-key": + authorization_header = "Authorization" if self.authorization.config and self.authorization.config.header: authorization_header = self.authorization.config.header @@ -296,21 +299,21 @@ def to_raw_request(self) -> str: raw_request += f'{k}: {"*" * len(v)}\n' continue - raw_request += f'{k}: {v}\n' + raw_request += f"{k}: {v}\n" - raw_request += '\n' + raw_request += "\n" # if files, use multipart/form-data with boundary if self.files: boundary = self.boundary - raw_request += f'--{boundary}' + raw_request += f"--{boundary}" for k, v in self.files.items(): raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f'{v[1]}\n' - raw_request += f'--{boundary}' - raw_request += '--' + raw_request += f"{v[1]}\n" + raw_request += f"--{boundary}" + raw_request += "--" else: - raw_request += self.body or '' + raw_request += self.body or "" return raw_request @@ -328,9 +331,9 @@ def _format_template( for variable_selector in variable_selectors: variable = variable_pool.get_any(variable_selector.value_selector) if variable is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace('\n', '\\n') + value = variable.replace('"', '\\"').replace("\n", "\\n") else: value = variable variable_value_mapping[variable_selector.variable] = value diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 3f68c8b1d03ce5..cd408191263f12 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -31,18 +31,18 @@ class HttpRequestNode(BaseNode): @classmethod def get_default_config(cls, filters: dict | None = None) -> dict: return { - 'type': 'http-request', - 'config': { - 'method': 'get', - 'authorization': { - 'type': 'no-auth', + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", }, - 'body': {'type': 'none'}, - 'timeout': { + "body": {"type": "none"}, + "timeout": { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - 'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - 'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } @@ -52,9 +52,8 @@ def _run(self) -> NodeRunResult: # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, - variable_pool=self.graph_runtime_state.variable_pool - ).text + template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None @@ -62,7 +61,7 @@ def _run(self) -> NodeRunResult: http_executor = HttpExecutor( node_data=node_data, timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool + variable_pool=self.graph_runtime_state.variable_pool, ) # invoke http executor @@ -71,7 +70,7 @@ def _run(self) -> NodeRunResult: process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -84,13 +83,13 @@ def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ - 'status_code': response.status_code, - 'body': response.content if not files else '', - 'headers': response.headers, - 'files': files, + "status_code": response.status_code, + "body": response.content if not files else "", + "headers": response.headers, + "files": files, }, process_data={ - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), }, ) @@ -107,10 +106,7 @@ def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeo @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -126,11 +122,11 @@ def _extract_variable_selector_to_variable_mapping( variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: - logging.exception(f'Failed to extract variable selector to variable mapping: {e}') + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: @@ -144,7 +140,7 @@ def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVa # extract filename from url filename = path.basename(url) # extract extension if possible - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 338277ace1a150..54c1081fd3b9c5 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -15,6 +15,7 @@ class Case(BaseModel): """ Case entity representing a single logical condition group """ + case_id: str logical_operator: Literal["and", "or"] conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index ca87eecd0d4c5f..5b4737c6e59116 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -20,13 +20,9 @@ def _run(self) -> NodeRunResult: node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs: dict[str, list] = { - "conditions": [] - } + node_inputs: dict[str, list] = {"conditions": []} - process_datas: dict[str, list] = { - "condition_results": [] - } + process_datas: dict[str, list] = {"condition_results": []} input_conditions = [] final_result = False @@ -37,8 +33,7 @@ def _run(self) -> NodeRunResult: if node_data.cases: for case in node_data.cases: input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions ) # Apply the logical operator for the current case @@ -60,8 +55,7 @@ def _run(self) -> NodeRunResult: else: # Fallback to old structure if cases are not defined input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=node_data.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) @@ -69,21 +63,14 @@ def _run(self) -> NodeRunResult: selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( - { - "group": "default", - "results": group_result, - "final_result": final_result - } + {"group": "default", "results": group_result, "final_result": final_result} ) node_inputs["conditions"] = input_conditions except Exception as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=node_inputs, - process_data=process_datas, - error=str(e) + status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) outputs = {"result": final_result, "selected_case_id": selected_case_id} @@ -93,17 +80,14 @@ def _run(self) -> NodeRunResult: inputs=node_inputs, process_data=process_datas, edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' - outputs=outputs + outputs=outputs, ) return data @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 5fc5a827aedd08..3c2c189159cbef 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector + + parent_loop_id: Optional[str] = None # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector class IterationStartNodeData(BaseNodeData): """ Iteration Start Node Data. """ + pass + class IterationState(BaseIterationState): """ Iteration State. """ + outputs: list[Any] = None current_output: Optional[Any] = None @@ -29,6 +33,7 @@ class MetaData(BaseIterationState.MetaData): """ Data. """ + iterator_length: int def get_last_output(self) -> Optional[Any]: @@ -38,9 +43,9 @@ def get_last_output(self) -> Optional[Any]: if self.outputs: return self.outputs[-1] return None - + def get_current_output(self) -> Optional[Any]: """ Get current output. """ - return self.current_output \ No newline at end of file + return self.current_output diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 93eff16c339558..77b14e36a19f83 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -33,6 +33,7 @@ class IterationNode(BaseNode): """ Iteration Node. """ + _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION @@ -45,31 +46,26 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: if not iterator_list_segment: raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - + iterator_list_value = iterator_list_segment.to_object() if not isinstance(iterator_list_value, list): raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - inputs = { - "iterator_selector": iterator_list_value - } + inputs = {"iterator_selector": iterator_list_value} graph_config = self.graph_config - + if not self.node_data.start_node_id: - raise ValueError(f'field start_node_id in iteration {self.node_id} not found') + raise ValueError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self.node_data.start_node_id # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=root_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') + raise ValueError("iteration graph not found") leaf_node_ids = iteration_graph.get_leaf_node_ids() iteration_leaf_node_ids = [] @@ -97,26 +93,21 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: Condition( variable_selector=[self.node_id, "index"], comparison_operator="<", - value=str(len(iterator_list_value)) + value=str(len(iterator_list_value)), ) - ] - ) + ], + ), ) variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool - variable_pool.add( - [self.node_id, 'index'], - 0 - ) - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[0] - ) + variable_pool.add([self.node_id, "index"], 0) + variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine from core.workflow.graph_engine.graph_engine import GraphEngine + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -130,7 +121,7 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: graph_config=graph_config, variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, ) start_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -142,10 +133,8 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - metadata={ - "iterator_length": len(iterator_list_value) - }, - predecessor_node_id=self.previous_node_id + metadata={"iterator_length": len(iterator_list_value)}, + predecessor_node_id=self.previous_node_id, ) yield IterationRunNextEvent( @@ -154,7 +143,7 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=0, - pre_iteration_output=None + pre_iteration_output=None, ) outputs: list[Any] = [] @@ -176,7 +165,9 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: if NodeRunMetadataKey.ITERATION_ID not in metadata: metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) event.route_node_state.node_run_result.metadata = metadata yield event @@ -192,21 +183,15 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: variable_pool.remove_node(node_id) # move to next iteration - current_index = variable_pool.get([self.node_id, 'index']) + current_index = variable_pool.get([self.node_id, "index"]) if current_index is None: - raise ValueError(f'iteration {self.node_id} current index not found') + raise ValueError(f"iteration {self.node_id} current index not found") next_index = int(current_index.to_object()) + 1 - variable_pool.add( - [self.node_id, 'index'], - next_index - ) + variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[next_index] - ) + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) yield IterationRunNextEvent( iteration_id=self.id, @@ -214,8 +199,9 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=next_index, - pre_iteration_output=jsonable_encoder( - current_iteration_output) if current_iteration_output else None + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): @@ -227,13 +213,9 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) @@ -255,21 +237,14 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - } + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) yield RunCompletedEvent( run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(outputs) - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} ) ) except Exception as e: @@ -282,16 +257,11 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=str(e), ) - yield RunCompletedEvent( run_result=NodeRunResult( @@ -301,15 +271,12 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: ) finally: # remove iteration variable (item, index) from variable pool after iteration run completed - variable_pool.remove([self.node_id, 'index']) - variable_pool.remove([self.node_id, 'item']) - + variable_pool.remove([self.node_id, "index"]) + variable_pool.remove([self.node_id, "item"]) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -319,36 +286,33 @@ def _extract_variable_selector_to_variable_mapping( :return: """ variable_mapping = { - f'{node_id}.input_selector': node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=node_data.start_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') - + raise ValueError("iteration graph not found") + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): - if sub_node_config.get('data', {}).get('iteration_id') != node_id: + if sub_node_config.get("data", {}).get("iteration_id") != node_id: continue # variable selector to variable mapping try: # Get node class from core.workflow.nodes.node_mapping import node_classes - node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + + node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: continue node_cls = cast(BaseNode, node_cls) - + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - config=sub_node_config + graph_config=graph_config, config=sub_node_config ) sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) except NotImplementedError: @@ -356,7 +320,8 @@ def _extract_variable_selector_to_variable_mapping( # remove iteration variables sub_node_variable_mapping = { - sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() + sub_node_id + "." + key: value + for key, value in sub_node_variable_mapping.items() if value[0] != node_id } @@ -364,8 +329,7 @@ def _extract_variable_selector_to_variable_mapping( # remove variable out from iteration variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_graph.node_ids + key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids } - + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 25044cf3eb99b2..88b9665ac6a673 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -11,6 +11,7 @@ class IterationStartNode(BaseNode): """ Iteration Start Node. """ + _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START @@ -18,16 +19,11 @@ def _run(self) -> NodeRunResult: """ Run the node. """ - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED - ) - + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 7cf392277caad6..1cd88039b12e30 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel): """ Reranking Model Config. """ + provider: str model: str @@ -17,6 +18,7 @@ class VectorSetting(BaseModel): """ Vector Setting. """ + vector_weight: float embedding_provider_name: str embedding_model_name: str @@ -26,6 +28,7 @@ class KeywordSetting(BaseModel): """ Keyword Setting. """ + keyword_weight: float @@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel): """ Weighted score Config. """ + vector_setting: VectorSetting keyword_setting: KeywordSetting @@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ + top_k: int score_threshold: Optional[float] = None - reranking_mode: str = 'reranking_model' + reranking_mode: str = "reranking_model" reranking_enable: bool = True reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. """ + model: ModelConfig @@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'knowledge-retrieval' + + type: str = "knowledge-retrieval" query_variable_selector: list[str] dataset_ids: list[str] - retrieval_mode: Literal['single', 'multiple'] + retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2d1ac4731c269f..19deca162aed24 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,14 +24,11 @@ logger = logging.getLogger(__name__) default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -45,62 +42,47 @@ def _run(self) -> NodeRunResult: # extract variables variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable - variables = { - 'query': query - } + variables = {"query": query} if not query: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Query is required." + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) # retrieve knowledge try: - results = self._fetch_dataset_retriever( - node_data=node_data, query=query - ) - outputs = { - 'result': results - } + results = self._fetch_dataset_retriever(node_data=node_data, query=query) + outputs = {"result": results} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=None, - outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) except Exception as e: logger.exception("Error when running knowledge retrieval node") - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ - dict[str, Any]]: + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] dataset_ids = node_data.dataset_ids # Subquery: Count the number of available documents for each dataset - subquery = db.session.query( - Document.dataset_id, - func.count(Document.id).label('available_document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids) - ).group_by(Document.dataset_id).having( - func.count(Document.id) > 0 - ).subquery() + subquery = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids), + ) + .group_by(Document.dataset_id) + .having(func.count(Document.id) > 0) + .subquery() + ) - results = db.session.query(Dataset).join( - subquery, Dataset.id == subquery.c.dataset_id - ).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id.in_(dataset_ids) - ).all() + results = ( + db.session.query(Dataset) + .join(subquery, Dataset.id == subquery.c.dataset_id) + .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .all() + ) for dataset in results: # pass if dataset is not available @@ -117,16 +99,14 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: model_type_instance = cast(LargeLanguageModel, model_type_instance) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if model_schema: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER all_documents = dataset_retrieval.single_retrieve( available_datasets=available_datasets, @@ -137,111 +117,111 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: query=query, model_config=model_config, model_instance=model_instance, - planning_strategy=planning_strategy + planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': + if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": reranking_model = { - 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider, - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, } weights = None - elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None weights = { - 'vector_setting': { + "vector_setting": { "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, }, - 'keyword_setting': { + "keyword_setting": { "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - } + }, } else: reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, - self.user_from.value, - available_datasets, query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, - ) + all_documents = dataset_retrieval.multiple_retrieve( + self.app_id, + self.tenant_id, + self.user_id, + self.user_from.value, + available_datasets, + query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) context_list = [] if all_documents: document_score_list = {} page_number_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] # both 'page' and 'score' are metadata fields - if item.metadata.get('page'): - page_number_list[item.metadata['doc_id']] = item.metadata['page'] + if item.metadata.get("page"): + page_number_list[item.metadata["doc_id"]] = item.metadata["page"] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() resource_number = 1 if dataset and document: source = { - 'metadata': { - '_source': 'knowledge', - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'document_data_source_type': document.data_source_type, - 'page': page_number_list.get(segment.index_node_id, None), - 'segment_id': segment.id, - 'retriever_from': 'workflow', - 'score': document_score_list.get(segment.index_node_id, None), - 'segment_hit_count': segment.hit_count, - 'segment_word_count': segment.word_count, - 'segment_position': segment.position, - 'segment_index_node_hash': segment.index_node_hash, + "metadata": { + "_source": "knowledge", + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "document_data_source_type": document.data_source_type, + "page": page_number_list.get(segment.index_node_id, None), + "segment_id": segment.id, + "retriever_from": "workflow", + "score": document_score_list.get(segment.index_node_id, None), + "segment_hit_count": segment.hit_count, + "segment_word_count": segment.word_count, + "segment_position": segment.position, + "segment_index_node_hash": segment.index_node_hash, }, - 'title': document.name + "title": document.name, } if segment.answer: - source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: - source['content'] = segment.get_sign_content() + source["content"] = segment.get_sign_content() context_list.append(source) resource_number += 1 return context_list @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: KnowledgeRetrievalNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -251,11 +231,12 @@ def _extract_variable_selector_to_variable_mapping( :return: """ variable_mapping = {} - variable_mapping[node_id + '.query'] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data: KnowledgeRetrievalNodeData + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -266,10 +247,7 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -280,8 +258,7 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -297,19 +274,16 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ # model config completion_params = node_data.single_retrieval_config.model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 1e48a10bc77012..93ee0ac2503b1f 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -11,6 +11,7 @@ class ModelConfig(BaseModel): """ Model Config. """ + provider: str name: str mode: str @@ -21,6 +22,7 @@ class ContextConfig(BaseModel): """ Context Config. """ + enabled: bool variable_selector: Optional[list[str]] = None @@ -29,37 +31,47 @@ class VisionConfig(BaseModel): """ Vision Config. """ + class Configs(BaseModel): """ Configs. """ - detail: Literal['low', 'high'] + + detail: Literal["low", "high"] enabled: bool configs: Optional[Configs] = None + class PromptConfig(BaseModel): """ Prompt Config. """ + jinja2_variables: Optional[list[VariableSelector]] = None + class LLMNodeChatModelMessage(ChatModelMessage): """ LLM Node Chat Model Message. """ + jinja2_text: Optional[str] = None + class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): """ LLM Node Chat Model Prompt Template. """ + jinja2_text: Optional[str] = None + class LLMNodeData(BaseNodeData): """ LLM Node Data. """ + model: ModelConfig prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] prompt_config: Optional[PromptConfig] = None diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index f26ec1b0b54269..6dfd27861eefc4 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -45,11 +45,11 @@ from core.file.file_obj import FileVar - class ModelInvokeCompleted(BaseModel): """ Model invoke completed """ + text: str usage: LLMUsage finish_reason: Optional[str] = None @@ -89,7 +89,7 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: files = self._fetch_files(node_data, variable_pool) if files: - node_inputs['#files#'] = [file.to_dict() for file in files] + node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value generator = self._fetch_context(node_data, variable_pool) @@ -100,7 +100,7 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: yield event if context: - node_inputs['#context#'] = context # type: ignore + node_inputs["#context#"] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -111,24 +111,22 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) - if node_data.memory else None, + query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'model_provider': model_config.provider, - 'model_name': model_config.model, + "model_provider": model_config.provider, + "model_name": model_config.model, } # handle invoke result @@ -136,10 +134,10 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, - stop=stop + stop=stop, ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -156,16 +154,12 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: status=WorkflowNodeExecutionStatus.FAILED, error=str(e), inputs=node_inputs, - process_data=process_data + process_data=process_data, ) ) return - outputs = { - 'text': result_text, - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} yield RunCompletedEvent( run_result=NodeRunResult( @@ -176,17 +170,19 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -206,9 +202,7 @@ def _invoke_llm(self, node_data_model: ModelConfig, ) # handle invoke result - generator = self._handle_invoke_result( - invoke_result=invoke_result - ) + generator = self._handle_invoke_result(invoke_result=invoke_result) usage = LLMUsage.empty_usage() for event in generator: @@ -219,8 +213,9 @@ def _invoke_llm(self, node_data_model: ModelConfig, # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result @@ -231,17 +226,14 @@ def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ model = None prompt_messages: list[PromptMessage] = [] - full_text = '' + full_text = "" usage = None finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text - yield RunStreamChunkEvent( - chunk_content=text, - from_variable_selector=[self.node_id, 'text'] - ) + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) if not model: model = result.model @@ -258,15 +250,11 @@ def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted( - text=full_text, - usage=usage, - finish_reason=finish_reason - ) + yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) - def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + def _transform_chat_messages( + self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -275,13 +263,13 @@ def _transform_chat_messages(self, """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2' and messages.jinja2_text: + if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2' and message.jinja2_text: + if message.edition_type == "jinja2" and message.jinja2_text: message.text = message.jinja2_text return messages @@ -300,17 +288,15 @@ def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePoo for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ Parse dict into string """ # check if it's a context structure - if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: - return d['content'] + if "metadata" in d and "_source" in d["metadata"] and "content" in d: + return d["content"] # else, parse the dict try: @@ -321,7 +307,7 @@ def parse_dict(d: dict) -> str: if isinstance(value, str): value = value elif isinstance(value, list): - result = '' + result = "" for item in value: if isinstance(item, dict): result += parse_dict(item) @@ -331,7 +317,7 @@ def parse_dict(d: dict) -> str: result += str(item) else: result += str(item) - result += '\n' + result += "\n" value = result.strip() elif isinstance(value, dict): value = parse_dict(value) @@ -366,18 +352,19 @@ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> for variable_selector in variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value @@ -393,7 +380,7 @@ def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> l if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) + files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) if not files: return [] @@ -415,29 +402,25 @@ def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - yield RunRetrieverResourceEvent( - retriever_resources=[], - context=context_value - ) + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) elif isinstance(context_value, list): - context_str = '' + context_str = "" original_retriever_resource = [] for item in context_value: if isinstance(item, str): - context_str += item + '\n' + context_str += item + "\n" else: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') + if "content" not in item: + raise ValueError(f"Invalid context structure: {item}") - context_str += item['content'] + '\n' + context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip() + retriever_resources=original_retriever_resource, context=context_str.strip() ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: @@ -446,34 +429,38 @@ def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optiona :param context_dict: context dict :return: """ - if ('metadata' in context_dict and '_source' in context_dict['metadata'] - and context_dict['metadata']['_source'] == 'knowledge'): - metadata = context_dict.get('metadata', {}) + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) source = { - 'position': metadata.get('position'), - 'dataset_id': metadata.get('dataset_id'), - 'dataset_name': metadata.get('dataset_name'), - 'document_id': metadata.get('document_id'), - 'document_name': metadata.get('document_name'), - 'data_source_type': metadata.get('document_data_source_type'), - 'segment_id': metadata.get('segment_id'), - 'retriever_from': metadata.get('retriever_from'), - 'score': metadata.get('score'), - 'hit_count': metadata.get('segment_hit_count'), - 'word_count': metadata.get('segment_word_count'), - 'segment_position': metadata.get('segment_position'), - 'index_node_hash': metadata.get('segment_index_node_hash'), - 'content': context_dict.get('content'), - 'page': metadata.get('page'), + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + "page": metadata.get("page"), } return source return None - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data_model: node data model @@ -484,10 +471,7 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -498,8 +482,7 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -515,19 +498,16 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ # model config completion_params = node_data_model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data_model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") @@ -543,9 +523,9 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ stop=stop, ) - def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], - variable_pool: VariablePool, - model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: """ Fetch memory :param node_data_memory: node data memory @@ -556,35 +536,35 @@ def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None # get conversation - conversation = db.session.query(Conversation).filter( - Conversation.app_id == self.app_id, - Conversation.id == conversation_id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) if not conversation: return None - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) return memory - def _fetch_prompt_messages(self, node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt_messages( + self, + node_data: LLMNodeData, + query: Optional[str], + query_prompt_template: Optional[str], + inputs: dict[str, str], + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt messages :param node_data: node data @@ -601,7 +581,7 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, prompt_messages = prompt_transform.get_prompt( prompt_template=node_data.prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=node_data.memory, @@ -621,8 +601,11 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( - content_item, ImagePromptMessageContent): + if ( + vision_enabled + and content_item.type == PromptMessageContentType.IMAGE + and isinstance(content_item, ImagePromptMessageContent) + ): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -632,15 +615,18 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, if len(prompt_message_content) > 1: prompt_message.content = prompt_message_content - elif (len(prompt_message_content) == 1 - and prompt_message_content[0].type == PromptMessageContentType.TEXT): + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): prompt_message.content = prompt_message_content[0].data filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError("No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding.") + raise ValueError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) return filtered_prompt_messages, stop @@ -678,7 +664,7 @@ def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_instance.model: + if "gpt-4" in model_instance.model: used_quota = 20 else: used_quota = 1 @@ -689,16 +675,13 @@ def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: Provider.provider_name == model_instance.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -712,11 +695,11 @@ def _extract_variable_selector_to_variable_mapping( variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type != 'jinja2': + if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) else: - if prompt_template.edition_type != 'jinja2': + if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() @@ -726,39 +709,38 @@ def _extract_variable_selector_to_variable_mapping( memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector if node_data.context.enabled: - variable_mapping['#context#'] = node_data.context.variable_selector + variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type == 'jinja2': + if prompt.edition_type == "jinja2": enable_jinja = True break else: - if prompt_template.edition_type == 'jinja2': + if prompt_template.edition_type == "jinja2": enable_jinja = True if enable_jinja: for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping @@ -775,26 +757,19 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: "prompt_templates": { "chat_model": { "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant.", - "edition_type": "basic" - } + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} ] }, "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { "text": "Here is the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic" + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", }, - "stop": ["Human:"] - } + "stop": ["Human:"], + }, } - } + }, } diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 8a5684551ef7f3..a8a0debe64284b 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,3 @@ - from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState @@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData): Loop Node Data. """ + class LoopState(BaseIterationState): """ Loop State. - """ \ No newline at end of file + """ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 526404e30d0d37..fbc68b79cb27de 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -10,6 +10,7 @@ class LoopNode(BaseNode): """ Loop Node. """ + _node_data_cls = LoopNodeData _node_type = NodeType.LOOP @@ -21,14 +22,16 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ Get conditions. """ - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: return [] # TODO waiting for implementation - return [Condition( - variable_selector=[node_id, 'index'], - comparison_operator="≤", - value_type="value_selector", - value_selector=[] - )] + return [ + Condition( + variable_selector=[node_id, "index"], + comparison_operator="≤", + value_type="value_selector", + value_selector=[], + ) + ] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 7bb123b1267c91..802ed31e27a42d 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -8,47 +8,52 @@ class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str completion_params: dict[str, Any] = {} + class ParameterConfig(BaseModel): """ Parameter Config. """ + name: str - type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] options: Optional[list[str]] = None description: str required: bool - @field_validator('name', mode='before') + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: if not value: - raise ValueError('Parameter name is required') - if value in ['__reason', '__is_success']: - raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + raise ValueError("Parameter name is required") + if value in ["__reason", "__is_success"]: + raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. """ + model: ModelConfig query: list[str] parameters: list[ParameterConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None - reasoning_mode: Literal['function_call', 'prompt'] + reasoning_mode: Literal["function_call", "prompt"] - @field_validator('reasoning_mode', mode='before') + @field_validator("reasoning_mode", mode="before") @classmethod def set_reasoning_mode(cls, v) -> str: - return v or 'function_call' + return v or "function_call" def get_parameter_json_schema(self) -> dict: """ @@ -56,32 +61,26 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = { - 'type': 'object', - 'properties': {}, - 'required': [] - } + parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = { - 'description': parameter.description - } - - if parameter.type in ['string', 'select']: - parameter_schema['type'] = 'string' - elif parameter.type.startswith('array'): - parameter_schema['type'] = 'array' + parameter_schema = {"description": parameter.description} + + if parameter.type in ["string", "select"]: + parameter_schema["type"] = "string" + elif parameter.type.startswith("array"): + parameter_schema["type"] = "array" nested_type = parameter.type[6:-1] - parameter_schema['items'] = {'type': nested_type} + parameter_schema["items"] = {"type": nested_type} else: - parameter_schema['type'] = parameter.type + parameter_schema["type"] = parameter.type + + if parameter.type == "select": + parameter_schema["enum"] = parameter.options - if parameter.type == 'select': - parameter_schema['enum'] = parameter.options + parameters["properties"][parameter.name] = parameter_schema - parameters['properties'][parameter.name] = parameter_schema - if parameter.required: - parameters['required'].append(parameter.name) + parameters["required"].append(parameter.name) - return parameters \ No newline at end of file + return parameters diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2e65705f10f2ce..131d26b19eaf31 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -45,6 +45,7 @@ class ParameterExtractorNode(LLMNode): """ Parameter Extractor Node. """ + _node_data_cls = ParameterExtractorNodeData _node_type = NodeType.PARAMETER_EXTRACTOR @@ -57,11 +58,8 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "stop": ["Human:"] + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "stop": ["Human:"], } } } @@ -78,9 +76,9 @@ def _run(self) -> NodeRunResult: query = variable inputs = { - 'query': query, - 'parameters': jsonable_encoder(node_data.parameters), - 'instruction': jsonable_encoder(node_data.instruction), + "query": query, + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), } model_instance, model_config = self._fetch_model_config(node_data.model) @@ -95,30 +93,29 @@ def _run(self) -> NodeRunResult: # fetch memory memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ - and node_data.reasoning_mode == 'function_call': - # use function call + if ( + set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} + and node_data.reasoning_mode == "function_call" + ): + # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, - query, - self.graph_runtime_state.variable_pool, - model_config, - memory) + prompt_messages = self._generate_prompt_engineering_prompt( + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + ) prompt_message_tools = [] process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': None, - 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - 'tool_call': None, + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + "tool_call": None, } try: @@ -129,20 +126,17 @@ def _run(self) -> NodeRunResult: tools=prompt_message_tools, stop=model_config.stop, ) - process_data['usage'] = jsonable_encoder(usage) - process_data['tool_call'] = jsonable_encoder(tool_call) - process_data['llm_text'] = text + process_data["usage"] = jsonable_encoder(usage) + process_data["tool_call"] = jsonable_encoder(tool_call) + process_data["llm_text"] = text except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 0, - '__reason': str(e) - }, + outputs={"__is_success": 0, "__reason": str(e)}, error=str(e), - metadata={} + metadata={}, ) error = None @@ -167,24 +161,23 @@ def _run(self) -> NodeRunResult: status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 1 if not error else 0, - '__reason': error, - **result - }, + outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str], + ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: """ Invoke large language model :param node_data_model: node data model @@ -217,32 +210,35 @@ def _invoke_llm(self, node_data_model: ModelConfig, return text, usage, tool_call - def _generate_function_call_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + def _generate_function_call_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps( - node_data.get_parameter_json_schema())) + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( + content=query, structure=json.dumps(node_data.get_parameter_json_schema()) + ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_function_calling_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -255,124 +251,125 @@ def _generate_function_call_prompt(self, example_messages = [] for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: id = uuid.uuid4().hex - example_messages.extend([ - UserPromptMessage(content=example['user']['query']), - AssistantPromptMessage( - content=example['assistant']['text'], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example['assistant']['function_call']['name'], - arguments=json.dumps(example['assistant']['function_call']['parameters'] - ) - )) - ] - ), - ToolPromptMessage( - content='Great! You have called the function with the correct parameters.', - tool_call_id=id - ), - AssistantPromptMessage( - content='I have extracted the parameters, let\'s move on.', - ) - ]) + example_messages.extend( + [ + UserPromptMessage(content=example["user"]["query"]), + AssistantPromptMessage( + content=example["assistant"]["text"], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example["assistant"]["function_call"]["name"], + arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), + ), + ) + ], + ), + ToolPromptMessage( + content="Great! You have called the function with the correct parameters.", tool_call_id=id + ), + AssistantPromptMessage( + content="I have extracted the parameters, let's move on.", + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) # generate tool tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, - description='Extract parameters from the natural language text', + description="Extract parameters from the natural language text", parameters=node_data.get_parameter_json_schema(), ) return prompt_messages, [tool] - def _generate_prompt_engineering_prompt(self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_prompt( + self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate prompt engineering prompt. """ model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) else: raise ValueError(f"Invalid model mode: {model_mode}") - def _generate_prompt_engineering_completion_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_completion_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate completion prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, - inputs={ - 'structure': json.dumps(node_data.get_parameter_json_schema()) - }, - query='', + inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _generate_prompt_engineering_chat_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_chat_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") prompt_template = self._get_prompt_engineering_prompt_template( node_data, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), - text=query + structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, memory, rest_token + variable_pool, + memory, + rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -384,18 +381,23 @@ def _generate_prompt_engineering_chat_prompt(self, # add example messages before last user message example_messages = [] for example in CHAT_EXAMPLE: - example_messages.extend([ - UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example['user']['json']), - text=example['user']['query'], - )), - AssistantPromptMessage( - content=json.dumps(example['assistant']['json']), - ) - ]) + example_messages.extend( + [ + UserPromptMessage( + content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example["user"]["json"]), + text=example["user"]["query"], + ) + ), + AssistantPromptMessage( + content=json.dumps(example["assistant"]["json"]), + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) return prompt_messages @@ -410,28 +412,28 @@ def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> di if parameter.required and parameter.name not in result: raise ValueError(f"Parameter {parameter.name} is required") - if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: raise ValueError(f"Invalid `select` value for parameter {parameter.name}") - if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): raise ValueError(f"Invalid `number` value for parameter {parameter.name}") - if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") - if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + if parameter.type == "string" and not isinstance(result.get(parameter.name), str): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") - if parameter.type.startswith('array'): + if parameter.type.startswith("array"): if not isinstance(result.get(parameter.name), list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in result.get(parameter.name): - if nested_type == 'number' and not isinstance(item, int | float): + if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == 'string' and not isinstance(item, str): + if nested_type == "string" and not isinstance(item, str): raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 'object' and not isinstance(item, dict): + if nested_type == "object" and not isinstance(item, dict): raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result @@ -443,12 +445,12 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d for parameter in data.parameters: if parameter.name in result: # transform value - if parameter.type == 'number': + if parameter.type == "number": if isinstance(result[parameter.name], int | float): transformed_result[parameter.name] = result[parameter.name] elif isinstance(result[parameter.name], str): try: - if '.' in result[parameter.name]: + if "." in result[parameter.name]: result[parameter.name] = float(result[parameter.name]) else: result[parameter.name] = int(result[parameter.name]) @@ -465,40 +467,40 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ['string', 'select']: + elif parameter.type in ["string", "select"]: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith('array'): + elif parameter.type.startswith("array"): if isinstance(result[parameter.name], list): nested_type = parameter.type[6:-1] transformed_result[parameter.name] = [] for item in result[parameter.name]: - if nested_type == 'number': + if nested_type == "number": if isinstance(item, int | float): transformed_result[parameter.name].append(item) elif isinstance(item, str): try: - if '.' in item: + if "." in item: transformed_result[parameter.name].append(float(item)) else: transformed_result[parameter.name].append(int(item)) except ValueError: pass - elif nested_type == 'string': + elif nested_type == "string": if isinstance(item, str): transformed_result[parameter.name].append(item) - elif nested_type == 'object': + elif nested_type == "object": if isinstance(item, dict): transformed_result[parameter.name].append(item) if parameter.name not in transformed_result: - if parameter.type == 'number': + if parameter.type == "number": transformed_result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ['string', 'select']: - transformed_result[parameter.name] = '' - elif parameter.type.startswith('array'): + elif parameter.type in ["string", "select"]: + transformed_result[parameter.name] = "" + elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] return transformed_result @@ -514,24 +516,24 @@ def extract_json(text): """ stack = [] for i, c in enumerate(text): - if c == '{' or c == '[': + if c == "{" or c == "[": stack.append(c) - elif c == '}' or c == ']': + elif c == "}" or c == "]": # check if stack is empty if not stack: return text[:i] # check if the last element in stack is matching - if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): stack.pop() if not stack: - return text[:i + 1] + return text[: i + 1] else: return text[:i] return None # extract json from the text for idx in range(len(result)): - if result[idx] == '{' or result[idx] == '[': + if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: try: @@ -554,12 +556,12 @@ def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ result = {} for parameter in data.parameters: - if parameter.type == 'number': + if parameter.type == "number": result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ['string', 'select']: - result[parameter.name] = '' + elif parameter.type in ["string", "select"]: + result[parameter.name] = "" return result @@ -575,71 +577,76 @@ def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: + def _get_function_calling_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: raise ValueError(f"Model mode {model_mode} not support.") - def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: - + def _get_prompt_engineering_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, - text=input_text, - instruction=instruction) - .replace('{γγγ', '') - .replace('}γγγ', '') + text=COMPLETION_GENERATE_JSON_PROMPT.format( + histories=memory_str, text=input_text, instruction=instruction + ) + .replace("{γγγ", "") + .replace("}γγγ", "") ) else: raise ValueError(f"Model mode {model_mode} not support.") - def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) model_instance, model_config = self._fetch_model_config(node_data.model) @@ -659,12 +666,12 @@ def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: st prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 @@ -673,26 +680,28 @@ def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: st model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - prompt_messages - ) + 1000 # add 1000 to ensure tool call messages + curr_message_tokens = ( + model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + ) # add 1000 to ensure tool call messages max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config. """ @@ -703,10 +712,7 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -715,17 +721,13 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ - variable_mapping = { - 'query': node_data.query - } + variable_mapping = {"query": node_data.query} if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 499c58d505832c..c63fded4d02573 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,4 +1,4 @@ -FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' +FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. ### Task @@ -35,61 +35,48 @@ """ -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + }, }, + "required": ["location"], }, - 'required': ['location'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'location': 'San Francisco' - } - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'food': 'apple pie' - } - } - } -}] +] COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: Some extra information are provided below, I should always follow the instructions as possible as I can. @@ -161,46 +148,33 @@ """ -CHAT_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'json': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True - } +CHAT_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "json": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + } + }, + "required": ["location"], }, - 'required': ['location'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'location': 'San Francisco' - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'json': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "json": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'result': 'apple pie' - } - } -}] \ No newline at end of file +] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c0b0a8b6968ead..40f7ce7582fdbc 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -8,8 +8,9 @@ class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -20,6 +21,7 @@ class ClassConfig(BaseModel): """ Class Config. """ + id: str name: str @@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ + query_variable_selector: list[str] - type: str = 'question-classifier' + type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ecab8db9b62b56..d860f848ec3857 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -45,34 +45,25 @@ def _run(self) -> NodeRunResult: # extract variables variable = variable_pool.get(node_data.query_variable_selector) query = variable.value if variable else None - variables = { - 'query': query - } + variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' + instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" node_data.instruction = instruction # fetch prompt messages prompt_messages, stop = self._fetch_prompt( - node_data=node_data, - context='', - query=query, - memory=memory, - model_config=model_config + node_data=node_data, context="", query=query, memory=memory, model_config=model_config ) # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop + node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -87,8 +78,8 @@ def _run(self) -> NodeRunResult: try: result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) - if 'category_name' in result_text_json and 'category_id' in result_text_json: - category_id_result = result_text_json['category_id'] + if "category_name" in result_text_json and "category_id" in result_text_json: + category_id_result = result_text_json["category_id"] classes = node_data.classes classes_map = {class_.id: class_.name for class_ in classes} category_ids = [_class.id for _class in classes] @@ -100,17 +91,14 @@ def _run(self) -> NodeRunResult: logging.error(f"Failed to parse result text: {result_text}") try: process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } - outputs = { - 'class_name': category_name + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, } + outputs = {"class_name": category_name} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -121,9 +109,9 @@ def _run(self) -> NodeRunResult: metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) except ValueError as e: @@ -134,17 +122,14 @@ def _run(self) -> NodeRunResult: metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -153,7 +138,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ - variable_mapping = {'query': node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) @@ -161,10 +146,8 @@ def _extract_variable_selector_to_variable_mapping( for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } - + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -174,19 +157,16 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: :param filters: filter by node config parameters. :return: """ - return { - "type": "question-classifier", - "config": { - "instructions": "" - } - } + return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt(self, node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt( + self, + node_data: QuestionClassifierNodeData, + query: str, + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt :param node_data: node data @@ -202,118 +182,122 @@ def _fetch_prompt(self, node_data: QuestionClassifierNodeData, prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: QuestionClassifierNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + def _get_prompt_template( + self, + node_data: QuestionClassifierNodeData, + query: str, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: - category = { - 'category_id': class_.id, - 'category_name': class_.name - } + category = {"category_id": class_.id, "category_name": class_.name} categories.append(category) - instruction = node_data.instruction if node_data.instruction else '' + instruction = node_data.instruction if node_data.instruction else "" input_text = query - memory_str = '' + memory_str = "" if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) prompt_messages = [] if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) user_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_1 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) assistant_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) user_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_2 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) assistant_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) user_prompt_message_3 = ChatModelMessage( role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction) + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( + input_text=input_text, + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction, + ), ) prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, - input_text=input_text, - categories=json.dumps(categories), - classification_instructions=instruction, - ensure_ascii=False) + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( + histories=memory_str, + input_text=input_text, + categories=json.dumps(categories), + classification_instructions=instruction, + ensure_ascii=False, + ) ) else: @@ -329,14 +313,12 @@ def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> variable = variable_pool.get(variable_selector.value_selector) variable_value = variable.value if variable else None if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - instruction = prompt_template.format( - prompt_inputs - ) + instruction = prompt_template.format(prompt_inputs) return instruction diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index e0de148cc2c08c..581f9869221632 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -1,5 +1,3 @@ - - QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Job Description', You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index b81ce15bd74c96..11d2ebe5ddb2cf 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -10,4 +10,5 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 69cdec6a9260a3..96c887c58d2c62 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping, Sequence from typing import Any @@ -22,20 +21,13 @@ def _run(self) -> NodeRunResult: system_inputs = self.graph_runtime_state.variable_pool.system_variables for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] + node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - outputs=node_inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: StartNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index d9099a8118498e..e934d69fa3049a 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,3 @@ - - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData): """ Code Node Data. """ + variables: list[VariableSelector] - template: str \ No newline at end of file + template: str diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index b14a394a0a8cfd..2829144ead80bb 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -8,7 +8,7 @@ from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) class TemplateTransformNode(BaseNode): @@ -24,15 +24,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ return { "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } + "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } def _run(self) -> NodeRunResult: @@ -51,38 +43,25 @@ def _run(self) -> NodeRunResult: # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=node_data.template, - inputs=variables + language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) except CodeExecutionException as e: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) + return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters", ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={ - 'output': result['result'] - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -92,5 +71,6 @@ def _extract_variable_selector_to_variable_mapping( :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 5da5cd07271bae..28fbf789fdd686 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -8,46 +8,47 @@ class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api', 'workflow'] - provider_name: str # redundancy + provider_type: Literal["builtin", "api", "workflow"] + provider_name: str # redundancy tool_name: str - tool_label: str # redundancy + tool_label: str # redundancy tool_configurations: dict[str, Any] - @field_validator('tool_configurations', mode='before') + @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value, values: ValidationInfo): if not isinstance(value, dict): - raise ValueError('tool_configurations must be a dictionary') - - for key in values.data.get('tool_configurations', {}).keys(): - value = values.data.get('tool_configurations', {}).get(key) + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}).keys(): + value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): - raise ValueError(f'{key} must be a string') - + raise ValueError(f"{key} must be a string") + return value + class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Literal['mixed', 'variable', 'constant'] + type: Literal["mixed", "variable", "constant"] - @field_validator('type', mode='before') + @field_validator("type", mode="before") @classmethod def check_type(cls, value, validation_info: ValidationInfo): typ = value - value = validation_info.data.get('value') - if typ == 'mixed' and not isinstance(value, str): - raise ValueError('value must be a string') - elif typ == 'variable': + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": if not isinstance(value, list): - raise ValueError('value must be a list') + raise ValueError("value must be a list") for val in value: if not isinstance(val, str): - raise ValueError('value must be a list of strings') - elif typ == 'constant' and not isinstance(value, str | int | float | bool): - raise ValueError('value must be a string, int, float, or bool') + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") return typ """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index feedeb6dad3824..e55adfc1f40272 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -34,10 +34,7 @@ def _run(self) -> NodeRunResult: node_data = cast(ToolNodeData, self.node_data) # fetch tool icon - tool_info = { - 'provider_type': node_data.provider_type, - 'provider_id': node_data.provider_id - } + tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id} # get tool runtime try: @@ -48,16 +45,21 @@ def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to get tool runtime: {str(e)}' + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to get tool runtime: {str(e)}", ) # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters( + tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data, + for_log=True, + ) try: messages = ToolEngine.workflow_invoke( @@ -72,10 +74,8 @@ def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to invoke tool: {str(e)}', + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", ) # convert tool messages @@ -83,15 +83,9 @@ def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'text': plain_text, - 'files': files, - 'json': json - }, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - inputs=parameters_for_log + outputs={"text": plain_text, "files": files, "json": json}, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + inputs=parameters_for_log, ) def _generate_parameters( @@ -123,12 +117,10 @@ def _generate_parameters( result[parameter_name] = None continue if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [ - v.to_dict() for v in self._fetch_files(variable_pool) - ] + result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] else: tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == 'variable': + if tool_input.type == "variable": # TODO: check if the variable exists in the variable pool parameter_value = variable_pool.get(tool_input.value).value else: @@ -142,12 +134,11 @@ def _generate_parameters( return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ - -> tuple[str, list[FileVar], list[dict]]: + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -172,38 +163,44 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): url = response.message ext = path.splitext(url)[1] - mimetype = response.meta.get('mime_type', 'image/jpeg') - filename = response.save_as or url.split('/')[-1] - transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + mimetype = response.meta.get("mime_type", "image/jpeg") + filename = response.save_as or url.split("/")[-1] + transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) # get tool file id - tool_file_id = url.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, - extension=ext, - mime_type=mimetype, - )) + tool_file_id = url.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=url, + related_id=tool_file_id, + filename=filename, + extension=ext, + mime_type=mimetype, + ) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, - extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get('mime_type', 'application/octet-stream'), - )) + tool_file_id = response.message.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=response.save_as, + extension=path.splitext(response.save_as)[1], + mime_type=response.meta.get("mime_type", "application/octet-stream"), + ) + ) elif response.type == ToolInvokeMessage.MessageType.LINK: pass # TODO: @@ -213,21 +210,23 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ Extract tool response text """ - return '\n'.join([ - f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' - for message in tool_response - ]) + return "\n".join( + [ + f"{message.message}" + if message.type == ToolInvokeMessage.MessageType.TEXT + else f"Link: {message.message}" + if message.type == ToolInvokeMessage.MessageType.LINK + else "" + for message in tool_response + ] + ) def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -239,17 +238,15 @@ def _extract_variable_selector_to_variable_mapping( result = {} for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': + if input.type == "mixed": selectors = VariableTemplateParser(input.value).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector - elif input.type == 'variable': + elif input.type == "variable": result[parameter_name] = input.value - elif input.type == 'constant': + elif input.type == "constant": pass - result = { - node_id + '.' + key: value for key, value in result.items() - } + result = {node_id + "." + key: value for key, value in result.items()} return result diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index e5de38dc0faa26..eb893a04e38f75 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ - - from typing import Literal, Optional from pydantic import BaseModel @@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel): """ Advanced setting. """ + group_enabled: bool class Group(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] + + output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] variables: list[list[str]] group_name: str groups: list[Group] + class VariableAssignerNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'variable-assigner' + + type: str = "variable-assigner" output_type: str variables: list[list[str]] advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 6944d9e82dcd41..f03eae257a9a21 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -21,13 +21,9 @@ def _run(self) -> NodeRunResult: for selector in node_data.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs = { - "output": variable - } + outputs = {"output": variable} - inputs = { - '.'.join(selector[1:]): variable - } + inputs = {".".join(selector[1:]): variable} break else: for group in node_data.advanced_settings.groups: @@ -35,24 +31,15 @@ def _run(self) -> NodeRunResult: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs[group.group_name] = { - 'output': variable - } - inputs['.'.join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable} + inputs[".".join(selector[1:])] = variable break - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index d791d51523ec0f..83da4bdc79bb21 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -2,7 +2,7 @@ from .node_data import VariableAssignerData, WriteMode __all__ = [ - 'VariableAssignerNode', - 'VariableAssignerData', - 'WriteMode', + "VariableAssignerNode", + "VariableAssignerData", + "WriteMode", ] diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index b2f32c6aaaea20..3969299795c0c1 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -24,43 +24,43 @@ def _run(self) -> NodeRunResult: # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') + raise VariableAssignerNodeError("assigned variable not found") match data.write_mode: case WriteMode.OVER_WRITE: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) + raise VariableAssignerNodeError("input value not found") + updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') + raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) + updated_variable = original_variable.model_copy(update={"value": updated_value}) case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') + raise VariableAssignerNodeError("conversation_id not found") update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ - 'value': income_value.to_object(), + "value": income_value.to_object(), }, ) @@ -72,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable): with Session(db.engine) as session: row = session.scalar(stmt) if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') + raise VariableAssignerNodeError("conversation variable not found in the database") row.data = variable.model_dump_json() session.commit() @@ -84,8 +84,8 @@ def get_zero_value(t: SegmentType): case SegmentType.OBJECT: return factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment('') + return factory.build_segment("") case SegmentType.NUMBER: return factory.build_segment(0) case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') + raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index b3652b68024265..8ac8eadf7ca5c3 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -6,14 +6,14 @@ class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' + OVER_WRITE = "over-write" + APPEND = "append" + CLEAR = "clear" class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' + title: str = "Variable Assigner" + desc: Optional[str] = "Assign a value to a variable" assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index e195730a31ebb5..b8e8b881a55956 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -7,11 +7,26 @@ class Condition(BaseModel): """ Condition entity """ + variable_selector: list[str] comparison_operator: Literal[ # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", ] value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 5ff61aab3d79ec..395ee824787f8f 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -15,9 +15,7 @@ def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[C index = 0 for condition in conditions: index += 1 - actual_value = variable_pool.get_any( - condition.variable_selector - ) + actual_value = variable_pool.get_any(condition.variable_selector) expected_value = None if condition.value is not None: @@ -25,9 +23,7 @@ def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[C variable_selectors = variable_template_parser.extract_variable_selectors() if variable_selectors: for variable_selector in variable_selectors: - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) expected_value = variable_template_parser.format({variable_selector.variable: value}) if expected_value is None: @@ -40,7 +36,7 @@ def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[C { "actual_value": actual_value, "expected_value": expected_value, - "comparison_operator": comparison_operator + "comparison_operator": comparison_operator, } ) @@ -50,10 +46,10 @@ def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[C return input_conditions, group_result def evaluate_condition( - self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], - comparison_operator: str, - expected_value: Optional[str] = None + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None, ) -> bool: """ Evaluate condition @@ -109,7 +105,7 @@ def _assert_contains(self, actual_value: Optional[str | list], expected_value: s return False if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value not in actual_value: return False @@ -126,7 +122,7 @@ def _assert_not_contains(self, actual_value: Optional[str | list], expected_valu return True if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value in actual_value: return False @@ -143,7 +139,7 @@ def _assert_start_with(self, actual_value: Optional[str], expected_value: str) - return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.startswith(expected_value): return False @@ -160,7 +156,7 @@ def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.endswith(expected_value): return False @@ -177,7 +173,7 @@ def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value != expected_value: return False @@ -194,7 +190,7 @@ def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bo return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value == expected_value: return False @@ -231,7 +227,7 @@ def _assert_equal(self, actual_value: Optional[int | float], expected_value: str return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -253,7 +249,7 @@ def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -275,7 +271,7 @@ def _assert_greater_than(self, actual_value: Optional[int | float], expected_val return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -297,7 +293,7 @@ def _assert_less_than(self, actual_value: Optional[int | float], expected_value: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -308,8 +304,9 @@ def _assert_less_than(self, actual_value: Optional[int | float], expected_value: return False return True - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_greater_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -320,7 +317,7 @@ def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -331,8 +328,9 @@ def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], return False return True - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_less_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert less than or equal :param actual_value: actual value @@ -343,7 +341,7 @@ def _assert_less_than_or_equal(self, actual_value: Optional[int | float], return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index c43fde172c7b7a..fd0e48b862d5f7 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -5,7 +5,7 @@ from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: @@ -20,7 +20,7 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) key_selectors = filter( lambda t: len(t[1]) >= 2, - ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), + ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), ) inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} @@ -29,13 +29,13 @@ def replacer(match): # return original matched string if key not found value = inputs.get(key, match.group(0)) if value is None: - value = '' + value = "" value = str(value) # remove template variables if required - return re.sub(REGEX, r'{\1}', value) + return re.sub(REGEX, r"{\1}", value) result = re.sub(REGEX, replacer, template) - result = re.sub(r'<\|.*?\|>', '', result) + result = re.sub(r"<\|.*?\|>", "", result) return result @@ -101,8 +101,8 @@ def extract_variable_selectors(self) -> list[VariableSelector]: """ variable_selectors = [] for variable_key in self.variable_keys: - remove_hash = variable_key.replace('#', '') - split_result = remove_hash.split('.') + remove_hash = variable_key.replace("#", "") + split_result = remove_hash.split(".") if len(split_result) < 2: continue @@ -127,7 +127,7 @@ def replacer(match): value = inputs.get(key, match.group(0)) # return original matched string if key not found if value is None: - value = '' + value = "" # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -136,7 +136,7 @@ def replacer(match): return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str): @@ -149,4 +149,4 @@ def remove_template_variables(cls, text: str): Returns: The text with template variables removed. """ - return re.sub(REGEX, r'{\1}', text) + return re.sub(REGEX, r"{\1}", text) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a359bd606ea2e2..25021935eeef2b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -33,19 +33,19 @@ class WorkflowEntry: def __init__( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - workflow_type: WorkflowType, - graph_config: Mapping[str, Any], - graph: Graph, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - variable_pool: VariablePool, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None, ) -> None: """ Init workflow entry @@ -65,7 +65,7 @@ def __init__( # check call depth workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state self.graph_engine = GraphEngine( @@ -82,13 +82,13 @@ def __init__( variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, - thread_pool_id=thread_pool_id + thread_pool_id=thread_pool_id, ) def run( - self, - *, - callbacks: Sequence[WorkflowCallback], + self, + *, + callbacks: Sequence[WorkflowCallback], ) -> Generator[GraphEngineEvent, None, None]: """ :param callbacks: workflow callbacks @@ -101,9 +101,7 @@ def run( for event in generator: if callbacks: for callback in callbacks: - callback.on_event( - event=event - ) + callback.on_event(event=event) yield event except GenerateTaskStoppedException: pass @@ -111,20 +109,12 @@ def run( logger.exception("Unknown Error when workflow entry running") if callbacks: for callback in callbacks: - callback.on_event( - event=GraphRunFailedEvent( - error=str(e) - ) - ) + callback.on_event(event=GraphRunFailedEvent(error=str(e))) return @classmethod def single_step_run( - cls, - workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict + cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -137,30 +127,30 @@ def single_step_run( # fetch node info from workflow graph graph = workflow.graph_dict if not graph: - raise ValueError('workflow graph not found') + raise ValueError("workflow graph not found") - nodes = graph.get('nodes') + nodes = graph.get("nodes") if not nodes: - raise ValueError('nodes not found in workflow graph') + raise ValueError("nodes not found in workflow graph") # fetch node config from node id node_config = None for node in nodes: - if node.get('id') == node_id: + if node.get("id") == node_id: node_config = node break if not node_config: - raise ValueError('node id not found in workflow graph') + raise ValueError("node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) if not node_cls: - raise ValueError(f'Node class not found for node type {node_type}') - + raise ValueError(f"Node class not found for node type {node_type}") + # init variable pool variable_pool = VariablePool( system_variables={}, @@ -169,9 +159,7 @@ def single_step_run( ) # init graph - graph = Graph.init( - graph_config=workflow.graph_dict - ) + graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state node_instance: BaseNode = node_cls( @@ -186,21 +174,17 @@ def single_step_run( user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, - call_depth=0 + call_depth=0, ), graph=graph, - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), ) try: # variable selector to variable mapping try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=node_config + graph_config=workflow.graph_dict, config=node_config ) except NotImplementedError: variable_mapping = {} @@ -211,7 +195,7 @@ def single_step_run( variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=node_instance.node_data + node_data=node_instance.node_data, ) # run node @@ -219,10 +203,7 @@ def single_step_run( return node_instance, generator except Exception as e: - raise WorkflowNodeRunFailedError( - node_instance=node_instance, - error=str(e) - ) + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @classmethod def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: @@ -259,21 +240,20 @@ def mapping_user_inputs_to_variable_pool( variable_pool: VariablePool, tenant_id: str, node_type: NodeType, - node_data: BaseNodeData + node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable - node_variable_list = node_variable.split('.') + node_variable_list = node_variable.split(".") if len(node_variable_list) < 1: - raise ValueError(f'Invalid node variable {node_variable}') - - node_variable_key = '.'.join(node_variable_list[1:]) + raise ValueError(f"Invalid node variable {node_variable}") + + node_variable_key = ".".join(node_variable_list[1:]) - if ( - node_variable_key not in user_inputs - and node_variable not in user_inputs - ) and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {node_variable} not found in user inputs.') + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( + variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") # fetch variable node id from variable selector variable_node_id = variable_selector[0] @@ -294,16 +274,17 @@ def mapping_user_inputs_to_variable_pool( detail = node_data.vision.configs.detail if node_data.vision.configs else None for item in input_value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + if isinstance(item, dict) and "type" in item and item["type"] == "image": + transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) file = FileVar( tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get("upload_file_id") + if transfer_method == FileTransferMethod.LOCAL_FILE + else None, + extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), ) new_value.append(file) diff --git a/api/pyproject.toml b/api/pyproject.toml index 69d1fc4ee02743..16b8c32c2a4a5f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -68,7 +68,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "core/**/*.py", "models/**/*.py", "migrations/**/*", ] From d1605952b0fb21c9a010b96ff8d61421f0c2f205 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 10 Sep 2024 17:01:32 +0800 Subject: [PATCH 08/13] fix: input chat input wrong padding (#8207) --- .../debug/debug-with-multiple-model/chat-item.tsx | 1 + .../debug/debug-with-multiple-model/index.tsx | 1 + .../configuration/debug/debug-with-single-model/index.tsx | 1 + web/app/components/base/chat/chat/chat-input.tsx | 7 +++++-- web/app/components/base/chat/chat/index.tsx | 5 ++++- .../workflow/panel/debug-and-preview/chat-wrapper.tsx | 1 + 6 files changed, 13 insertions(+), 3 deletions(-) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index d4a156032465f5..80dfb5c534ea5c 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -129,6 +129,7 @@ const ChatItem: FC = ({ questionIcon={} allToolIcons={allToolIcons} hideLogModal + noSpacing /> ) } diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index 892d0cfe8b3298..3d2f3bca590042 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -128,6 +128,7 @@ const DebugWithMultipleModel = () => { onSend={handleSend} speechToTextConfig={speechToTextConfig} visionConfig={visionConfig} + noSpacing />
) diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index 0f09a232425e54..d93ad00659e478 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -130,6 +130,7 @@ const DebugWithSingleModel = forwardRef ) }) diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index c4578fab62b011..a90f86402d110a 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -32,18 +32,21 @@ import { useDraggableUploader, useImageFiles, } from '@/app/components/base/image-uploader/hooks' +import cn from '@/utils/classnames' type ChatInputProps = { visionConfig?: VisionConfig speechToTextConfig?: EnableType onSend?: OnSend theme?: Theme | null + noSpacing?: boolean } const ChatInput: FC = ({ visionConfig, speechToTextConfig, onSend, theme, + noSpacing, }) => { const { appData } = useChatWithHistoryContext() const { t } = useTranslation() @@ -146,7 +149,7 @@ const ChatInput: FC = ({ return ( <> -
+
= ({ onDrop={onDrop} autoSize /> -
+
{query.trim().length}
diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 7540cd873b9460..65e49eff67dd5f 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -60,6 +60,7 @@ export type ChatProps = { hideProcessDetail?: boolean hideLogModal?: boolean themeBuilder?: ThemeBuilder + noSpacing?: boolean } const Chat: FC = ({ @@ -89,6 +90,7 @@ const Chat: FC = ({ hideProcessDetail, hideLogModal, themeBuilder, + noSpacing, }) => { const { t } = useTranslation() const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ @@ -197,7 +199,7 @@ const Chat: FC = ({ {chatNode}
{ chatList.map((item, index) => { @@ -268,6 +270,7 @@ const Chat: FC = ({ speechToTextConfig={config?.speech_to_text} onSend={onSend} theme={themeBuilder?.theme} + noSpacing={noSpacing} /> ) } diff --git a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx index e7639ddc794083..a7dd607e221ea5 100644 --- a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx @@ -118,6 +118,7 @@ const ChatWrapper = forwardRef(({ showConv } )} + noSpacing suggestedQuestions={suggestedQuestions} showPromptLog chatAnswerContainerInner='!pr-2' From d109881410e6630c582f4eeb812e53b144e6189b Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Tue, 10 Sep 2024 17:08:06 +0800 Subject: [PATCH 09/13] chore(api/models): apply ruff reformatting (#7600) Co-authored-by: -LAN- --- api/models/__init__.py | 10 +- api/models/account.py | 169 +++--- api/models/api_based_extension.py | 18 +- api/models/dataset.py | 581 ++++++++++---------- api/models/model.py | 883 ++++++++++++++++-------------- api/models/provider.py | 122 +++-- api/models/source.py | 48 +- api/models/task.py | 21 +- api/models/tool.py | 20 +- api/models/tools.py | 142 ++--- api/models/types.py | 6 +- api/models/web.py | 21 +- api/models/workflow.py | 357 ++++++------ api/pyproject.toml | 1 - 14 files changed, 1263 insertions(+), 1136 deletions(-) diff --git a/api/models/__init__.py b/api/models/__init__.py index 4012611471c337..30ceef057e28af 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -4,7 +4,7 @@ from .types import StringUUID from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus -__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] +__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] class CreatedByRole(Enum): @@ -12,11 +12,11 @@ class CreatedByRole(Enum): Enum class for createdByRole """ - ACCOUNT = 'account' - END_USER = 'end_user' + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. @@ -26,4 +26,4 @@ def value_of(cls, value: str) -> 'CreatedByRole': for role in cls: if role.value == value: return role - raise ValueError(f'invalid createdByRole value {value}') + raise ValueError(f"invalid createdByRole value {value}") diff --git a/api/models/account.py b/api/models/account.py index 67d940b7b7190e..60b4f11aad2bdc 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -9,21 +9,18 @@ class AccountStatus(str, enum.Enum): - PENDING = 'pending' - UNINITIALIZED = 'uninitialized' - ACTIVE = 'active' - BANNED = 'banned' - CLOSED = 'closed' + PENDING = "pending" + UNINITIALIZED = "uninitialized" + ACTIVE = "active" + BANNED = "banned" + CLOSED = "closed" class Account(UserMixin, db.Model): - __tablename__ = 'accounts' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_pkey'), - db.Index('account_email_idx', 'email') - ) + __tablename__ = "accounts" + __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -34,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def is_password_set(self): @@ -65,11 +62,13 @@ def current_tenant_id(self): @current_tenant_id.setter def current_tenant_id(self, value: str): try: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == value) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.account_id == self.id) \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == value) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == self.id) .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join @@ -91,20 +90,18 @@ def get_status(self) -> AccountStatus: @classmethod def get_by_openid(cls, provider: str, open_id: str) -> db.Model: - account_integrate = db.session.query(AccountIntegrate). \ - filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \ - one_or_none() + account_integrate = ( + db.session.query(AccountIntegrate) + .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .one_or_none() + ) if account_integrate: - return db.session.query(Account). \ - filter(Account.id == account_integrate.account_id). \ - one_or_none() + return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None def get_integrates(self) -> list[db.Model]: ai = db.Model - return db.session.query(ai).filter( - ai.account_id == self.id - ).all() + return db.session.query(ai).filter(ai.account_id == self.id).all() # check current_user.current_tenant.current_role in ['admin', 'owner'] @property @@ -123,61 +120,75 @@ def is_dataset_editor(self): def is_dataset_operator(self): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR + class TenantStatus(str, enum.Enum): - NORMAL = 'normal' - ARCHIVE = 'archive' + NORMAL = "normal" + ARCHIVE = "archive" class TenantAccountRole(str, enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - EDITOR = 'editor' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" @staticmethod def is_valid_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } @staticmethod def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} - + @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR} - + return role and role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + @staticmethod def is_editing_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + class Tenant(db.Model): - __tablename__ = 'tenants' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_pkey'), - ) + __tablename__ = "tenants" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def get_accounts(self) -> list[Account]: - return db.session.query(Account).filter( - Account.id == TenantAccountJoin.account_id, - TenantAccountJoin.tenant_id == self.id - ).all() + return ( + db.session.query(Account) + .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .all() + ) @property def custom_config_dict(self) -> dict: @@ -189,54 +200,54 @@ def custom_config_dict(self, value: dict): class TenantAccountJoinRole(enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" class TenantAccountJoin(db.Model): - __tablename__ = 'tenant_account_joins' + __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), - db.Index('tenant_account_join_account_id_idx', 'account_id'), - db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'), - db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + db.Index("tenant_account_join_account_id_idx", "account_id"), + db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - role = db.Column(db.String(16), nullable=False, server_default='normal') + current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AccountIntegrate(db.Model): - __tablename__ = 'account_integrates' + __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_integrate_pkey'), - db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), - db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class InvitationCode(db.Model): - __tablename__ = 'invitation_codes' + __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='invitation_code_pkey'), - db.Index('invitation_codes_batch_idx', 'batch'), - db.Index('invitation_codes_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + db.Index("invitation_codes_batch_idx", "batch"), + db.Index("invitation_codes_code_idx", "code", "status"), ) id = db.Column(db.Integer, nullable=False) @@ -247,4 +258,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 7f69323628a7cc..97173747afc4b1 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -6,22 +6,22 @@ class APIBasedExtensionPoint(enum.Enum): - APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' - PING = 'ping' - APP_MODERATION_INPUT = 'app.moderation.input' - APP_MODERATION_OUTPUT = 'app.moderation.output' + APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" + PING = "ping" + APP_MODERATION_INPUT = "app.moderation.input" + APP_MODERATION_OUTPUT = "app.moderation.output" class APIBasedExtension(db.Model): - __tablename__ = 'api_based_extensions' + __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), - db.Index('api_based_extension_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index bf3f12a2c56e15..55f6ed3180b9c4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -24,37 +24,34 @@ class DatasetPermissionEnum(str, enum.Enum): - ONLY_ME = 'only_me' - ALL_TEAM = 'all_team_members' - PARTIAL_TEAM = 'partial_members' + ONLY_ME = "only_me" + ALL_TEAM = "all_team_members" + PARTIAL_TEAM = "partial_members" + class Dataset(db.Model): - __tablename__ = 'datasets' + __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_pkey'), - db.Index('dataset_tenant_idx', 'tenant_id'), - db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="dataset_pkey"), + db.Index("dataset_tenant_idx", "tenant_id"), + db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, - server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, - server_default=db.text("'only_me'::character varying")) + provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) + permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -62,8 +59,9 @@ class Dataset(db.Model): @property def dataset_keyword_table(self): - dataset_keyword_table = db.session.query(DatasetKeywordTable).filter( - DatasetKeywordTable.dataset_id == self.id).first() + dataset_keyword_table = ( + db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + ) if dataset_keyword_table: return dataset_keyword_table @@ -79,13 +77,19 @@ def created_by_account(self): @property def latest_process_rule(self): - return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \ - .order_by(DatasetProcessRule.created_at.desc()).first() + return ( + DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + .order_by(DatasetProcessRule.created_at.desc()) + .first() + ) @property def app_count(self): - return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, - App.id == AppDatasetJoin.app_id).scalar() + return ( + db.session.query(func.count(AppDatasetJoin.id)) + .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .scalar() + ) @property def document_count(self): @@ -93,30 +97,40 @@ def document_count(self): @property def available_document_count(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.id, - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False - ).scalar() + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def available_segment_count(self): - return db.session.query(func.count(DocumentSegment.id)).filter( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).scalar() + return ( + db.session.query(func.count(DocumentSegment.id)) + .filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .scalar() + ) @property def word_count(self): - return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ - .filter(Document.dataset_id == self.id).scalar() + return ( + Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + .filter(Document.dataset_id == self.id) + .scalar() + ) @property def doc_form(self): - document = db.session.query(Document).filter( - Document.dataset_id == self.id).first() + document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -124,76 +138,68 @@ def doc_form(self): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } return self.retrieval_model if self.retrieval_model else default_retrieval_model @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'knowledge' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "knowledge", + ) + .all() + ) return tags if tags else [] @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f'Vector_index_{normalized_dataset_id}_Node' + return f"Vector_index_{normalized_dataset_id}_Node" class DatasetProcessRule(db.Model): - __tablename__ = 'dataset_process_rules' + __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'), - db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, - server_default=db.text("'automatic'::character varying")) + mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - MODES = ['automatic', 'custom'] - PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] + MODES = ["automatic", "custom"] + PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES = { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } def to_dict(self): return { - 'id': self.id, - 'dataset_id': self.dataset_id, - 'mode': self.mode, - 'rules': self.rules_dict, - 'created_by': self.created_by, - 'created_at': self.created_at, + "id": self.id, + "dataset_id": self.dataset_id, + "mode": self.mode, + "rules": self.rules_dict, + "created_by": self.created_by, + "created_at": self.created_at, } @property @@ -205,17 +211,16 @@ def rules_dict(self): class Document(db.Model): - __tablename__ = 'documents' + __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_pkey'), - db.Index('document_dataset_id_idx', 'dataset_id'), - db.Index('document_is_paused_idx', 'is_paused'), - db.Index('document_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_pkey"), + db.Index("document_dataset_id_idx", "dataset_id"), + db.Index("document_is_paused_idx", "is_paused"), + db.Index("document_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) @@ -227,8 +232,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -250,7 +254,7 @@ class Document(db.Model): completed_at = db.Column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) @@ -259,44 +263,39 @@ class Document(db.Model): stopped_at = db.Column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String( - 255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, - server_default=db.text('false')) + archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - doc_form = db.Column(db.String( - 255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] + DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @property def display_status(self): status = None - if self.indexing_status == 'waiting': - status = 'queuing' - elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused: - status = 'paused' - elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: - status = 'indexing' - elif self.indexing_status == 'error': - status = 'error' - elif self.indexing_status == 'completed' and not self.archived and self.enabled: - status = 'available' - elif self.indexing_status == 'completed' and not self.archived and not self.enabled: - status = 'disabled' - elif self.indexing_status == 'completed' and self.archived: - status = 'archived' + if self.indexing_status == "waiting": + status = "queuing" + elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: + status = "paused" + elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + status = "indexing" + elif self.indexing_status == "error": + status = "error" + elif self.indexing_status == "completed" and not self.archived and self.enabled: + status = "available" + elif self.indexing_status == "completed" and not self.archived and not self.enabled: + status = "disabled" + elif self.indexing_status == "completed" and self.archived: + status = "archived" return status @property @@ -313,24 +312,26 @@ def data_source_info_dict(self): @property def data_source_detail_dict(self): if self.data_source_info: - if self.data_source_type == 'upload_file': + if self.data_source_type == "upload_file": data_source_info_dict = json.loads(self.data_source_info) - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info_dict['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .one_or_none() + ) if file_detail: return { - 'upload_file': { - 'id': file_detail.id, - 'name': file_detail.name, - 'size': file_detail.size, - 'extension': file_detail.extension, - 'mime_type': file_detail.mime_type, - 'created_by': file_detail.created_by, - 'created_at': file_detail.created_at.timestamp() + "upload_file": { + "id": file_detail.id, + "name": file_detail.name, + "size": file_detail.size, + "extension": file_detail.extension, + "mime_type": file_detail.mime_type, + "created_by": file_detail.created_by, + "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': + elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": return json.loads(self.data_source_info) return {} @@ -356,120 +357,123 @@ def segment_count(self): @property def hit_count(self): - return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ - .filter(DocumentSegment.document_id == self.id).scalar() + return ( + DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + .filter(DocumentSegment.document_id == self.id) + .scalar() + ) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'dataset_id': self.dataset_id, - 'position': self.position, - 'data_source_type': self.data_source_type, - 'data_source_info': self.data_source_info, - 'dataset_process_rule_id': self.dataset_process_rule_id, - 'batch': self.batch, - 'name': self.name, - 'created_from': self.created_from, - 'created_by': self.created_by, - 'created_api_request_id': self.created_api_request_id, - 'created_at': self.created_at, - 'processing_started_at': self.processing_started_at, - 'file_id': self.file_id, - 'word_count': self.word_count, - 'parsing_completed_at': self.parsing_completed_at, - 'cleaning_completed_at': self.cleaning_completed_at, - 'splitting_completed_at': self.splitting_completed_at, - 'tokens': self.tokens, - 'indexing_latency': self.indexing_latency, - 'completed_at': self.completed_at, - 'is_paused': self.is_paused, - 'paused_by': self.paused_by, - 'paused_at': self.paused_at, - 'error': self.error, - 'stopped_at': self.stopped_at, - 'indexing_status': self.indexing_status, - 'enabled': self.enabled, - 'disabled_at': self.disabled_at, - 'disabled_by': self.disabled_by, - 'archived': self.archived, - 'archived_reason': self.archived_reason, - 'archived_by': self.archived_by, - 'archived_at': self.archived_at, - 'updated_at': self.updated_at, - 'doc_type': self.doc_type, - 'doc_metadata': self.doc_metadata, - 'doc_form': self.doc_form, - 'doc_language': self.doc_language, - 'display_status': self.display_status, - 'data_source_info_dict': self.data_source_info_dict, - 'average_segment_length': self.average_segment_length, - 'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - 'dataset': self.dataset.to_dict() if self.dataset else None, - 'segment_count': self.segment_count, - 'hit_count': self.hit_count + "id": self.id, + "tenant_id": self.tenant_id, + "dataset_id": self.dataset_id, + "position": self.position, + "data_source_type": self.data_source_type, + "data_source_info": self.data_source_info, + "dataset_process_rule_id": self.dataset_process_rule_id, + "batch": self.batch, + "name": self.name, + "created_from": self.created_from, + "created_by": self.created_by, + "created_api_request_id": self.created_api_request_id, + "created_at": self.created_at, + "processing_started_at": self.processing_started_at, + "file_id": self.file_id, + "word_count": self.word_count, + "parsing_completed_at": self.parsing_completed_at, + "cleaning_completed_at": self.cleaning_completed_at, + "splitting_completed_at": self.splitting_completed_at, + "tokens": self.tokens, + "indexing_latency": self.indexing_latency, + "completed_at": self.completed_at, + "is_paused": self.is_paused, + "paused_by": self.paused_by, + "paused_at": self.paused_at, + "error": self.error, + "stopped_at": self.stopped_at, + "indexing_status": self.indexing_status, + "enabled": self.enabled, + "disabled_at": self.disabled_at, + "disabled_by": self.disabled_by, + "archived": self.archived, + "archived_reason": self.archived_reason, + "archived_by": self.archived_by, + "archived_at": self.archived_at, + "updated_at": self.updated_at, + "doc_type": self.doc_type, + "doc_metadata": self.doc_metadata, + "doc_form": self.doc_form, + "doc_language": self.doc_language, + "display_status": self.display_status, + "data_source_info_dict": self.data_source_info_dict, + "average_segment_length": self.average_segment_length, + "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, + "dataset": self.dataset.to_dict() if self.dataset else None, + "segment_count": self.segment_count, + "hit_count": self.hit_count, } @classmethod def from_dict(cls, data: dict): return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - dataset_id=data.get('dataset_id'), - position=data.get('position'), - data_source_type=data.get('data_source_type'), - data_source_info=data.get('data_source_info'), - dataset_process_rule_id=data.get('dataset_process_rule_id'), - batch=data.get('batch'), - name=data.get('name'), - created_from=data.get('created_from'), - created_by=data.get('created_by'), - created_api_request_id=data.get('created_api_request_id'), - created_at=data.get('created_at'), - processing_started_at=data.get('processing_started_at'), - file_id=data.get('file_id'), - word_count=data.get('word_count'), - parsing_completed_at=data.get('parsing_completed_at'), - cleaning_completed_at=data.get('cleaning_completed_at'), - splitting_completed_at=data.get('splitting_completed_at'), - tokens=data.get('tokens'), - indexing_latency=data.get('indexing_latency'), - completed_at=data.get('completed_at'), - is_paused=data.get('is_paused'), - paused_by=data.get('paused_by'), - paused_at=data.get('paused_at'), - error=data.get('error'), - stopped_at=data.get('stopped_at'), - indexing_status=data.get('indexing_status'), - enabled=data.get('enabled'), - disabled_at=data.get('disabled_at'), - disabled_by=data.get('disabled_by'), - archived=data.get('archived'), - archived_reason=data.get('archived_reason'), - archived_by=data.get('archived_by'), - archived_at=data.get('archived_at'), - updated_at=data.get('updated_at'), - doc_type=data.get('doc_type'), - doc_metadata=data.get('doc_metadata'), - doc_form=data.get('doc_form'), - doc_language=data.get('doc_language') + id=data.get("id"), + tenant_id=data.get("tenant_id"), + dataset_id=data.get("dataset_id"), + position=data.get("position"), + data_source_type=data.get("data_source_type"), + data_source_info=data.get("data_source_info"), + dataset_process_rule_id=data.get("dataset_process_rule_id"), + batch=data.get("batch"), + name=data.get("name"), + created_from=data.get("created_from"), + created_by=data.get("created_by"), + created_api_request_id=data.get("created_api_request_id"), + created_at=data.get("created_at"), + processing_started_at=data.get("processing_started_at"), + file_id=data.get("file_id"), + word_count=data.get("word_count"), + parsing_completed_at=data.get("parsing_completed_at"), + cleaning_completed_at=data.get("cleaning_completed_at"), + splitting_completed_at=data.get("splitting_completed_at"), + tokens=data.get("tokens"), + indexing_latency=data.get("indexing_latency"), + completed_at=data.get("completed_at"), + is_paused=data.get("is_paused"), + paused_by=data.get("paused_by"), + paused_at=data.get("paused_at"), + error=data.get("error"), + stopped_at=data.get("stopped_at"), + indexing_status=data.get("indexing_status"), + enabled=data.get("enabled"), + disabled_at=data.get("disabled_at"), + disabled_by=data.get("disabled_by"), + archived=data.get("archived"), + archived_reason=data.get("archived_reason"), + archived_by=data.get("archived_by"), + archived_at=data.get("archived_at"), + updated_at=data.get("updated_at"), + doc_type=data.get("doc_type"), + doc_metadata=data.get("doc_metadata"), + doc_form=data.get("doc_form"), + doc_language=data.get("doc_language"), ) + class DocumentSegment(db.Model): - __tablename__ = 'document_segments' + __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_segment_pkey'), - db.Index('document_segment_dataset_id_idx', 'dataset_id'), - db.Index('document_segment_document_id_idx', 'document_id'), - db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), - db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), - db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), - db.Index('document_segment_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_segment_pkey"), + db.Index("document_segment_dataset_id_idx", "dataset_id"), + db.Index("document_segment_document_id_idx", "document_id"), + db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), + db.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -486,18 +490,14 @@ class DocumentSegment(db.Model): # basic fields hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, - server_default=db.text("'waiting'::character varying")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -513,17 +513,19 @@ def document(self): @property def previous_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position - 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) + .first() + ) @property def next_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position + 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) + .first() + ) def get_sign_content(self): pattern = r"/files/([a-f0-9\-]+)/image-preview" @@ -535,7 +537,7 @@ def get_sign_content(self): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -546,21 +548,20 @@ def get_sign_content(self): # Reconstruct the text with signed URLs offset = 0 for start, end, signed_url in signed_urls: - text = text[:start + offset] + signed_url + text[end + offset:] + text = text[: start + offset] + signed_url + text[end + offset :] offset += len(signed_url) - (end - start) return text - class AppDatasetJoin(db.Model): - __tablename__ = 'app_dataset_joins' + __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), - db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -571,13 +572,13 @@ def app(self): class DatasetQuery(db.Model): - __tablename__ = 'dataset_queries' + __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), - db.Index('dataset_query_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) @@ -588,17 +589,18 @@ class DatasetQuery(db.Model): class DatasetKeywordTable(db.Model): - __tablename__ = 'dataset_keyword_tables' + __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), - db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False, - server_default=db.text("'database'::character varying")) + data_source_type = db.Column( + db.String(255), nullable=False, server_default=db.text("'database'::character varying") + ) @property def keyword_table_dict(self): @@ -614,19 +616,17 @@ def object_hook(self, dct): return dct # get dataset - dataset = Dataset.query.filter_by( - id=self.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=self.dataset_id).first() if not dataset: return None - if self.data_source_type == 'database': + if self.data_source_type == "database": return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None else: - file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt' + file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" try: keyword_table_text = storage.load_once(file_key) if keyword_table_text: - return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder) + return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) return None except Exception as e: logging.exception(str(e)) @@ -634,21 +634,21 @@ def object_hook(self, dct): class Embedding(db.Model): - __tablename__ = 'embeddings' + __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'), - db.Index('created_at_idx', 'created_at') + db.PrimaryKeyConstraint("id", name="embedding_pkey"), + db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - model_name = db.Column(db.String(255), nullable=False, - server_default=db.text("'text-embedding-ada-002'::character varying")) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = db.Column( + db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - provider_name = db.Column(db.String(255), nullable=False, - server_default=db.text("''::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -658,33 +658,32 @@ def get_embedding(self) -> list[float]: class DatasetCollectionBinding(db.Model): - __tablename__ = 'dataset_collection_bindings' + __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), - db.Index('provider_model_name_idx', 'provider_name', 'model_name') - + db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetPermission(db.Model): - __tablename__ = 'dataset_permissions' + __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'), - db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'), - db.Index('idx_dataset_permissions_account_id', 'account_id'), - db.Index('idx_dataset_permissions_tenant_id', 'tenant_id') + db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + db.Index("idx_dataset_permissions_account_id", "account_id"), + db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index e81d25fbc9daea..8ab3026522faa5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -20,25 +20,23 @@ class DifySetup(db.Model): - __tablename__ = 'dify_setups' - __table_args__ = ( - db.PrimaryKeyConstraint('version', name='dify_setup_pkey'), - ) + __tablename__ = "dify_setups" + __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AppMode(Enum): - COMPLETION = 'completion' - WORKFLOW = 'workflow' - CHAT = 'chat' - ADVANCED_CHAT = 'advanced-chat' - AGENT_CHAT = 'agent-chat' - CHANNEL = 'channel' + COMPLETION = "completion" + WORKFLOW = "workflow" + CHAT = "chat" + ADVANCED_CHAT = "advanced-chat" + AGENT_CHAT = "agent-chat" + CHANNEL = "channel" @classmethod - def value_of(cls, value: str) -> 'AppMode': + def value_of(cls, value: str) -> "AppMode": """ Get value of given mode. @@ -48,21 +46,19 @@ def value_of(cls, value: str) -> 'AppMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class IconType(Enum): IMAGE = "image" EMOJI = "emoji" + class App(db.Model): - __tablename__ = 'apps' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_pkey'), - db.Index('app_tenant_id_idx', 'tenant_id') - ) + __tablename__ = "apps" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) @@ -75,17 +71,17 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -97,7 +93,7 @@ def desc_or_prompt(self): if app_model_config: return app_model_config.pre_prompt else: - return '' + return "" @property def site(self): @@ -105,24 +101,24 @@ def site(self): return site @property - def app_model_config(self) -> Optional['AppModelConfig']: + def app_model_config(self) -> Optional["AppModelConfig"]: if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional['Workflow']: + def workflow(self) -> Optional["Workflow"]: if self.workflow_id: from .workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return None @property def api_base_url(self): - return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1" @property def tenant(self): @@ -136,8 +132,9 @@ def is_agent(self) -> bool: return False if not app_model_config.agent_mode: return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( + "strategy", "" + ) in ["function_call", "react"]: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -159,16 +156,16 @@ def deleted_tools(self) -> list: if not app_model_config.agent_mode: return [] agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) + tools = agent_mode.get("tools", []) provider_ids = [] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api': + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api": # check if provider id is a uuid string, if not, skip try: uuid.UUID(provider_id) @@ -180,8 +177,7 @@ def deleted_tools(self) -> list: return [] api_providers = db.session.execute( - text('SELECT id FROM tool_api_providers WHERE id IN :provider_ids'), - {'provider_ids': tuple(provider_ids)} + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} ).fetchall() deleted_tools = [] @@ -190,44 +186,43 @@ def deleted_tools(self) -> list: for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api' and provider_id not in current_api_provider_ids: - deleted_tools.append(tool['tool_name']) + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api" and provider_id not in current_api_provider_ids: + deleted_tools.append(tool["tool_name"]) return deleted_tools @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'app' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "app", + ) + .all() + ) return tags if tags else [] class AppModelConfig(db.Model): - __tablename__ = 'app_model_configs' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), - db.Index('app_app_id_idx', 'app_id') - ) + __tablename__ = "app_model_configs" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -263,28 +258,29 @@ def suggested_questions_list(self) -> list: @property def suggested_questions_after_answer_dict(self) -> dict: - return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ + return ( + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer else {"enabled": False} + ) @property def speech_to_text_dict(self) -> dict: - return json.loads(self.speech_to_text) if self.speech_to_text \ - else {"enabled": False} + return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property def text_to_speech_dict(self) -> dict: - return json.loads(self.text_to_speech) if self.text_to_speech \ - else {"enabled": False} + return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property def retriever_resource_dict(self) -> dict: - return json.loads(self.retriever_resource) if self.retriever_resource \ - else {"enabled": True} + return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property def annotation_reply_dict(self) -> dict: - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -293,8 +289,8 @@ def annotation_reply_dict(self) -> dict: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } else: @@ -306,13 +302,15 @@ def more_like_this_dict(self) -> dict: @property def sensitive_word_avoidance_dict(self) -> dict: - return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ + return ( + json.loads(self.sensitive_word_avoidance) + if self.sensitive_word_avoidance else {"enabled": False, "type": "", "configs": []} + ) @property def external_data_tools_list(self) -> list[dict]: - return json.loads(self.external_data_tools) if self.external_data_tools \ - else [] + return json.loads(self.external_data_tools) if self.external_data_tools else [] @property def user_input_form_list(self) -> dict: @@ -320,8 +318,11 @@ def user_input_form_list(self) -> dict: @property def agent_mode_dict(self) -> dict: - return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], - "prompt": None} + return ( + json.loads(self.agent_mode) + if self.agent_mode + else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + ) @property def chat_prompt_config_dict(self) -> dict: @@ -335,19 +336,28 @@ def completion_prompt_config_dict(self) -> dict: def dataset_configs_dict(self) -> dict: if self.dataset_configs: dataset_configs = json.loads(self.dataset_configs) - if 'retrieval_model' not in dataset_configs: - return {'retrieval_model': 'single'} + if "retrieval_model" not in dataset_configs: + return {"retrieval_model": "single"} else: return dataset_configs return { - 'retrieval_model': 'multiple', - } + "retrieval_model": "multiple", + } @property def file_upload_dict(self) -> dict: - return json.loads(self.file_upload) if self.file_upload else { - "image": {"enabled": False, "number_limits": 3, "detail": "high", - "transfer_methods": ["remote_url", "local_file"]}} + return ( + json.loads(self.file_upload) + if self.file_upload + else { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + } + ) def to_dict(self) -> dict: return { @@ -370,44 +380,53 @@ def to_dict(self) -> dict: "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, - "file_upload": self.file_upload_dict + "file_upload": self.file_upload_dict, } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config.get('opening_statement') - self.suggested_questions = json.dumps(model_config['suggested_questions']) \ - if model_config.get('suggested_questions') else None - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ - if model_config.get('suggested_questions_after_answer') else None - self.speech_to_text = json.dumps(model_config['speech_to_text']) \ - if model_config.get('speech_to_text') else None - self.text_to_speech = json.dumps(model_config['text_to_speech']) \ - if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) \ - if model_config.get('more_like_this') else None - self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ - if model_config.get('sensitive_word_avoidance') else None - self.external_data_tools = json.dumps(model_config['external_data_tools']) \ - if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) \ - if model_config.get('model') else None - self.user_input_form = json.dumps(model_config['user_input_form']) \ - if model_config.get('user_input_form') else None - self.dataset_query_variable = model_config.get('dataset_query_variable') - self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) \ - if model_config.get('agent_mode') else None - self.retriever_resource = json.dumps(model_config['retriever_resource']) \ - if model_config.get('retriever_resource') else None - self.prompt_type = model_config.get('prompt_type', 'simple') - self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ - if model_config.get('chat_prompt_config') else None - self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ - if model_config.get('completion_prompt_config') else None - self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ - if model_config.get('dataset_configs') else None - self.file_upload = json.dumps(model_config.get('file_upload')) \ - if model_config.get('file_upload') else None + self.opening_statement = model_config.get("opening_statement") + self.suggested_questions = ( + json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + ) + self.suggested_questions_after_answer = ( + json.dumps(model_config["suggested_questions_after_answer"]) + if model_config.get("suggested_questions_after_answer") + else None + ) + self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None + self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None + self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.sensitive_word_avoidance = ( + json.dumps(model_config["sensitive_word_avoidance"]) + if model_config.get("sensitive_word_avoidance") + else None + ) + self.external_data_tools = ( + json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + ) + self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.user_input_form = ( + json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + ) + self.dataset_query_variable = model_config.get("dataset_query_variable") + self.pre_prompt = model_config["pre_prompt"] + self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.retriever_resource = ( + json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + ) + self.prompt_type = model_config.get("prompt_type", "simple") + self.chat_prompt_config = ( + json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None + ) + self.completion_prompt_config = ( + json.dumps(model_config.get("completion_prompt_config")) + if model_config.get("completion_prompt_config") + else None + ) + self.dataset_configs = ( + json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None + ) + self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None return self def copy(self): @@ -432,21 +451,21 @@ def copy(self): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, ) return new_app_model_config class RecommendedApp(db.Model): - __tablename__ = 'recommended_apps' + __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='recommended_app_pkey'), - db.Index('recommended_app_app_id_idx', 'app_id'), - db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') + db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + db.Index("recommended_app_app_id_idx", "app_id"), + db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) @@ -457,8 +476,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -467,22 +486,22 @@ def app(self): class InstalledApp(db.Model): - __tablename__ = 'installed_apps' + __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='installed_app_pkey'), - db.Index('installed_app_tenant_id_idx', 'tenant_id'), - db.Index('installed_app_app_id_idx', 'app_id'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + db.PrimaryKeyConstraint("id", name="installed_app_pkey"), + db.Index("installed_app_tenant_id_idx", "tenant_id"), + db.Index("installed_app_app_id_idx", "app_id"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -496,13 +515,13 @@ def tenant(self): class Conversation(db.Model): - __tablename__ = 'conversations' + __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='conversation_pkey'), - db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') + db.PrimaryKeyConstraint("id", name="conversation_pkey"), + db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -514,7 +533,7 @@ class Conversation(db.Model): inputs = db.Column(db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) @@ -523,13 +542,15 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") + message_annotations = db.relationship( + "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def model_config(self): @@ -542,20 +563,21 @@ def model_config(self): if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - if 'model' in override_model_configs: + if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) model_config = app_model_config.to_dict() else: - model_config['configs'] = override_model_configs + model_config["configs"] = override_model_configs else: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + ) model_config = app_model_config.to_dict() - model_config['model_id'] = self.model_id - model_config['provider'] = self.model_provider + model_config["model_id"] = self.model_id + model_config["provider"] = self.model_provider return model_config @@ -568,7 +590,7 @@ def summary_or_query(self): if first_message: return first_message.query else: - return '' + return "" @property def annotated(self): @@ -584,31 +606,51 @@ def message_count(self): @property def user_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def admin_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def first_message(self): @@ -642,33 +684,33 @@ def in_debug_mode(self): class Message(db.Model): - __tablename__ = 'messages' + __tablename__ = "messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_pkey'), - db.Index('message_app_id_idx', 'app_id', 'created_at'), - db.Index('message_conversation_id_idx', 'conversation_id'), - db.Index('message_end_user_idx', 'app_id', 'from_source', 'from_end_user_id'), - db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), - db.Index('message_workflow_run_id_idx', 'conversation_id', 'workflow_run_id') + db.PrimaryKeyConstraint("id", name="message_pkey"), + db.Index("message_app_id_idx", "app_id", "created_at"), + db.Index("message_conversation_id_idx", "conversation_id"), + db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), + db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) inputs = db.Column(db.JSON) query = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) - message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) answer = db.Column(db.Text, nullable=False) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) @@ -678,9 +720,9 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @property @@ -688,7 +730,7 @@ def re_sign_file_url_answer(self) -> str: if not self.answer: return self.answer - pattern = r'\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)' + pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" matches = re.findall(pattern, self.answer) if not matches: @@ -704,9 +746,9 @@ def re_sign_file_url_answer(self) -> str: re_sign_file_url_answer = self.answer for url in urls: - if 'files/tools' in url: + if "files/tools" in url: # get tool file id - tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp=' + tool_file_id_pattern = r"\/files\/tools\/([\.\w-]+)?\?timestamp=" result = re.search(tool_file_id_pattern, url) if not result: continue @@ -714,25 +756,24 @@ def re_sign_file_url_answer(self) -> str: tool_file_id = result.group(1) # get extension - if '.' in tool_file_id: - split_result = tool_file_id.split('.') - extension = f'.{split_result[-1]}' + if "." in tool_file_id: + split_result = tool_file_id.split(".") + extension = f".{split_result[-1]}" if len(extension) > 10: - extension = '.bin' + extension = ".bin" tool_file_id = split_result[0] else: - extension = '.bin' + extension = ".bin" if not tool_file_id: continue sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, - extension=extension + tool_file_id=tool_file_id, extension=extension ) else: # get upload file id - upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp=' + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue @@ -750,14 +791,20 @@ def re_sign_file_url_answer(self) -> str: @property def user_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'user').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .first() + ) return feedback @property def admin_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'admin').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .first() + ) return feedback @property @@ -772,11 +819,15 @@ def annotation(self): @property def annotation_hit_history(self): - annotation_history = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.message_id == self.id).first()) + annotation_history = ( + db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + ) if annotation_history: - annotation = (db.session.query(MessageAnnotation). - filter(MessageAnnotation.id == annotation_history.annotation_id).first()) + annotation = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.id == annotation_history.annotation_id) + .first() + ) return annotation return None @@ -784,8 +835,9 @@ def annotation_hit_history(self): def app_model_config(self): conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() if conversation: - return db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id).first() + return ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() + ) return None @@ -799,13 +851,21 @@ def message_metadata_dict(self) -> dict: @property def agent_thoughts(self): - return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ - .order_by(MessageAgentThought.position.asc()).all() + return ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == self.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) @property def retriever_resources(self): - return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ - .order_by(DatasetRetrieverResource.position.asc()).all() + return ( + db.session.query(DatasetRetrieverResource) + .filter(DatasetRetrieverResource.message_id == self.id) + .order_by(DatasetRetrieverResource.position.asc()) + .all() + ) @property def message_files(self): @@ -818,39 +878,39 @@ def files(self): files = [] for message_file in message_files: url = message_file.url - if message_file.type == 'image': - if message_file.transfer_method == 'local_file': - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == message_file.upload_file_id - ).first()) - - url = UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=True + if message_file.type == "image": + if message_file.transfer_method == "local_file": + upload_file = ( + db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first() ) - if message_file.transfer_method == 'tool_file': + + url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True) + if message_file.transfer_method == "tool_file": # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension) + url = ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=tool_file_id, extension=extension + ) - files.append({ - 'id': message_file.id, - 'type': message_file.type, - 'url': url, - 'belongs_to': message_file.belongs_to if message_file.belongs_to else 'user' - }) + files.append( + { + "id": message_file.id, + "type": message_file.type, + "url": url, + "belongs_to": message_file.belongs_to if message_file.belongs_to else "user", + } + ) return files @@ -858,64 +918,65 @@ def files(self): def workflow_run(self): if self.workflow_run_id: from .workflow import WorkflowRun + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None def to_dict(self) -> dict: return { - 'id': self.id, - 'app_id': self.app_id, - 'conversation_id': self.conversation_id, - 'inputs': self.inputs, - 'query': self.query, - 'message': self.message, - 'answer': self.answer, - 'status': self.status, - 'error': self.error, - 'message_metadata': self.message_metadata_dict, - 'from_source': self.from_source, - 'from_end_user_id': self.from_end_user_id, - 'from_account_id': self.from_account_id, - 'created_at': self.created_at.isoformat(), - 'updated_at': self.updated_at.isoformat(), - 'agent_based': self.agent_based, - 'workflow_run_id': self.workflow_run_id + "id": self.id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "inputs": self.inputs, + "query": self.query, + "message": self.message, + "answer": self.answer, + "status": self.status, + "error": self.error, + "message_metadata": self.message_metadata_dict, + "from_source": self.from_source, + "from_end_user_id": self.from_end_user_id, + "from_account_id": self.from_account_id, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "agent_based": self.agent_based, + "workflow_run_id": self.workflow_run_id, } @classmethod def from_dict(cls, data: dict): return cls( - id=data['id'], - app_id=data['app_id'], - conversation_id=data['conversation_id'], - inputs=data['inputs'], - query=data['query'], - message=data['message'], - answer=data['answer'], - status=data['status'], - error=data['error'], - message_metadata=json.dumps(data['message_metadata']), - from_source=data['from_source'], - from_end_user_id=data['from_end_user_id'], - from_account_id=data['from_account_id'], - created_at=data['created_at'], - updated_at=data['updated_at'], - agent_based=data['agent_based'], - workflow_run_id=data['workflow_run_id'] + id=data["id"], + app_id=data["app_id"], + conversation_id=data["conversation_id"], + inputs=data["inputs"], + query=data["query"], + message=data["message"], + answer=data["answer"], + status=data["status"], + error=data["error"], + message_metadata=json.dumps(data["message_metadata"]), + from_source=data["from_source"], + from_end_user_id=data["from_end_user_id"], + from_account_id=data["from_account_id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + agent_based=data["agent_based"], + workflow_run_id=data["workflow_run_id"], ) class MessageFeedback(db.Model): - __tablename__ = 'message_feedbacks' + __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_feedback_pkey'), - db.Index('message_feedback_app_idx', 'app_id'), - db.Index('message_feedback_message_idx', 'message_id', 'from_source'), - db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') + db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + db.Index("message_feedback_app_idx", "app_id"), + db.Index("message_feedback_message_idx", "message_id", "from_source"), + db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -924,8 +985,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def from_account(self): @@ -934,14 +995,14 @@ def from_account(self): class MessageFile(db.Model): - __tablename__ = 'message_files' + __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_file_pkey'), - db.Index('message_file_message_idx', 'message_id'), - db.Index('message_file_created_by_idx', 'created_by') + db.PrimaryKeyConstraint("id", name="message_file_pkey"), + db.Index("message_file_message_idx", "message_id"), + db.Index("message_file_created_by_idx", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False) @@ -950,28 +1011,28 @@ class MessageFile(db.Model): upload_file_id = db.Column(StringUUID, nullable=True) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageAnnotation(db.Model): - __tablename__ = 'message_annotations' + __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), - db.Index('message_annotation_app_idx', 'app_id'), - db.Index('message_annotation_conversation_idx', 'conversation_id'), - db.Index('message_annotation_message_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + db.Index("message_annotation_app_idx", "app_id"), + db.Index("message_annotation_conversation_idx", "conversation_id"), + db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def account(self): @@ -985,32 +1046,35 @@ def annotation_create_account(self): class AppAnnotationHitHistory(db.Model): - __tablename__ = 'app_annotation_hit_histories' + __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), - db.Index('app_annotation_hit_histories_app_idx', 'app_id'), - db.Index('app_annotation_hit_histories_account_idx', 'account_id'), - db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), - db.Index('app_annotation_hit_histories_message_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + db.Index("app_annotation_hit_histories_app_idx", "app_id"), + db.Index("app_annotation_hit_histories_account_idx", "account_id"), + db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - score = db.Column(Float, nullable=False, server_default=db.text('0')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False) @property def account(self): - account = (db.session.query(Account) - .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id) + .first() + ) return account @property @@ -1020,89 +1084,99 @@ def annotation_create_account(self): class AppAnnotationSetting(db.Model): - __tablename__ = 'app_annotation_settings' + __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), - db.Index('app_annotation_settings_app_idx', 'app_id') + db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) + score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def created_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def updated_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = (db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) + + collection_binding_detail = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .first() + ) return collection_binding_detail class OperationLog(db.Model): - __tablename__ = 'operation_logs' + __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='operation_log_pkey'), - db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') + db.PrimaryKeyConstraint("id", name="operation_log_pkey"), + db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class EndUser(UserMixin, db.Model): - __tablename__ = 'end_users' + __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='end_user_pkey'), - db.Index('end_user_session_id_idx', 'session_id', 'type'), - db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), + db.PrimaryKeyConstraint("id", name="end_user_pkey"), + db.Index("end_user_session_id_idx", "session_id", "type"), + db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class Site(db.Model): - __tablename__ = 'sites' + __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='site_pkey'), - db.Index('site_app_id_idx', 'app_id'), - db.Index('site_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="site_pkey"), + db.Index("site_app_id_idx", "app_id"), + db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) icon_type = db.Column(db.String(255), nullable=True) @@ -1111,20 +1185,20 @@ class Site(db.Model): description = db.Column(db.Text) default_language = db.Column(db.String(255), nullable=False) chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) custom_disclaimer = db.Column(db.String(255), nullable=True) customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) @staticmethod @@ -1138,26 +1212,25 @@ def generate_code(n): @property def app_base_url(self): - return ( - dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip('/')) + return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/") class ApiToken(db.Model): - __tablename__ = 'api_tokens' + __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_token_pkey'), - db.Index('api_token_app_id_type_idx', 'app_id', 'type'), - db.Index('api_token_token_idx', 'token', 'type'), - db.Index('api_token_tenant_idx', 'tenant_id', 'type') + db.PrimaryKeyConstraint("id", name="api_token_pkey"), + db.Index("api_token_app_id_type_idx", "app_id", "type"), + db.Index("api_token_token_idx", "token", "type"), + db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @staticmethod def generate_api_key(prefix, n): @@ -1170,13 +1243,13 @@ def generate_api_key(prefix, n): class UploadFile(db.Model): - __tablename__ = 'upload_files' + __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='upload_file_pkey'), - db.Index('upload_file_tenant_idx', 'tenant_id') + db.PrimaryKeyConstraint("id", name="upload_file_pkey"), + db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) storage_type = db.Column(db.String(255), nullable=False) key = db.Column(db.String(255), nullable=False) @@ -1186,38 +1259,38 @@ class UploadFile(db.Model): mime_type = db.Column(db.String(255), nullable=True) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + used = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by = db.Column(StringUUID, nullable=True) used_at = db.Column(db.DateTime, nullable=True) hash = db.Column(db.String(255), nullable=True) class ApiRequest(db.Model): - __tablename__ = 'api_requests' + __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_request_pkey'), - db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') + db.PrimaryKeyConstraint("id", name="api_request_pkey"), + db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageChain(db.Model): - __tablename__ = 'message_chains' + __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_chain_pkey'), - db.Index('message_chain_message_id_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_chain_pkey"), + db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) @@ -1226,14 +1299,14 @@ class MessageChain(db.Model): class MessageAgentThought(db.Model): - __tablename__ = 'message_agent_thoughts' + __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_agent_thought_pkey'), - db.Index('message_agent_thought_message_id_idx', 'message_id'), - db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), + db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + db.Index("message_agent_thought_message_id_idx", "message_id"), + db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) @@ -1248,12 +1321,12 @@ class MessageAgentThought(db.Model): message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) currency = db.Column(db.String, nullable=True) @@ -1310,9 +1383,7 @@ def tool_inputs_dict(self) -> dict: result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: return {} @@ -1333,22 +1404,20 @@ def tool_outputs_dict(self) -> dict: result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): - __tablename__ = 'dataset_retriever_resources' + __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), - db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1369,53 +1438,53 @@ class DatasetRetrieverResource(db.Model): class Tag(db.Model): - __tablename__ = 'tags' + __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_pkey'), - db.Index('tag_type_idx', 'type'), - db.Index('tag_name_idx', 'name'), + db.PrimaryKeyConstraint("id", name="tag_pkey"), + db.Index("tag_type_idx", "type"), + db.Index("tag_name_idx", "name"), ) - TAG_TYPE_LIST = ['knowledge', 'app'] + TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TagBinding(db.Model): - __tablename__ = 'tag_bindings' + __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_binding_pkey'), - db.Index('tag_bind_target_id_idx', 'target_id'), - db.Index('tag_bind_tag_id_idx', 'tag_id'), + db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + db.Index("tag_bind_target_id_idx", "target_id"), + db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TraceAppConfig(db.Model): - __tablename__ = 'trace_app_config' + __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tracing_app_config_pkey'), - db.Index('trace_app_config_app_id_idx', 'app_id'), + db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): @@ -1427,11 +1496,11 @@ def tracing_config_str(self): def to_dict(self): return { - 'id': self.id, - 'app_id': self.app_id, - 'tracing_provider': self.tracing_provider, - 'tracing_config': self.tracing_config_dict, + "id": self.id, + "app_id": self.app_id, + "tracing_provider": self.tracing_provider, + "tracing_config": self.tracing_config_dict, "is_active": self.is_active, "created_at": self.created_at.__str__() if self.created_at else None, - 'updated_at': self.updated_at.__str__() if self.updated_at else None, + "updated_at": self.updated_at.__str__() if self.updated_at else None, } diff --git a/api/models/provider.py b/api/models/provider.py index 5d92ee6eb60d18..ff63e03a92e735 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,8 +6,8 @@ class ProviderType(Enum): - CUSTOM = 'custom' - SYSTEM = 'system' + CUSTOM = "custom" + SYSTEM = "system" @staticmethod def value_of(value): @@ -18,13 +18,13 @@ def value_of(value): class ProviderQuotaType(Enum): - PAID = 'paid' + PAID = "paid" """hosted paid quota""" - FREE = 'free' + FREE = "free" """third-party free quota""" - TRIAL = 'trial' + TRIAL = "trial" """hosted trial quota""" @staticmethod @@ -39,27 +39,30 @@ class Provider(db.Model): """ Provider model representing the API providers and their configurations. """ - __tablename__ = 'providers' + + __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_pkey'), - db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + db.PrimaryKeyConstraint("id", name="provider_pkey"), + db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used = db.Column(db.DateTime, nullable=True) quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def __repr__(self): return f"" @@ -67,8 +70,8 @@ def __repr__(self): @property def token_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_config is not None @property @@ -86,118 +89,123 @@ class ProviderModel(db.Model): """ Provider model representing the API provider_models and their configurations. """ - __tablename__ = 'provider_models' + + __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_pkey'), - db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + db.PrimaryKeyConstraint("id", name="provider_model_pkey"), + db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantDefaultModel(db.Model): - __tablename__ = 'tenant_default_models' + __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), - db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPreferredModelProvider(db.Model): - __tablename__ = 'tenant_preferred_model_providers' + __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), - db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderOrder(db.Model): - __tablename__ = 'provider_orders' + __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_order_pkey'), - db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="provider_order_pkey"), + db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) currency = db.Column(db.String(40)) total_amount = db.Column(db.Integer) payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderModelSetting(db.Model): """ Provider model settings for record the model enabled status and load balancing status. """ - __tablename__ = 'provider_model_settings' + + __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), - db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class LoadBalancingModelConfig(db.Model): """ Configurations for load balancing models. """ - __tablename__ = 'load_balancing_model_configs' + + __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), - db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/source.py b/api/models/source.py index adc00028bee43b..07695f06e6cf00 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -8,48 +8,48 @@ class DataSourceOauthBinding(db.Model): - __tablename__ = 'data_source_oauth_bindings' + __tablename__ = "data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='source_binding_pkey'), - db.Index('source_binding_tenant_id_idx', 'tenant_id'), - db.Index('source_info_idx', "source_info", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="source_binding_pkey"), + db.Index("source_binding_tenant_id_idx", "tenant_id"), + db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(db.Model): - __tablename__ = 'data_source_api_key_auth_bindings' + __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), - db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), - db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), + db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'category': self.category, - 'provider': self.provider, - 'credentials': json.loads(self.credentials), - 'created_at': self.created_at.timestamp(), - 'updated_at': self.updated_at.timestamp(), - 'disabled': self.disabled + "id": self.id, + "tenant_id": self.tenant_id, + "category": self.category, + "provider": self.provider, + "credentials": json.loads(self.credentials), + "created_at": self.created_at.timestamp(), + "updated_at": self.updated_at.timestamp(), + "disabled": self.disabled, } diff --git a/api/models/task.py b/api/models/task.py index 618d831d8ed4e9..57b147c78db110 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -8,15 +8,18 @@ class CeleryTask(db.Model): """Task result/status.""" - __tablename__ = 'celery_taskmeta' + __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence('task_id_sequence'), - primary_key=True, autoincrement=True) + id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column( + db.DateTime, + default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + nullable=True, + ) traceback = db.Column(db.Text, nullable=True) name = db.Column(db.String(155), nullable=True) args = db.Column(db.LargeBinary, nullable=True) @@ -29,11 +32,9 @@ class CeleryTask(db.Model): class CeleryTaskSet(db.Model): """TaskSet result.""" - __tablename__ = 'celery_tasksetmeta' + __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), - autoincrement=True, primary_key=True) + id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) diff --git a/api/models/tool.py b/api/models/tool.py index 79a70c6b1f2d22..a81bb65174a724 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -7,7 +7,7 @@ class ToolProviderName(Enum): - SERPAPI = 'serpapi' + SERPAPI = "serpapi" @staticmethod def value_of(value): @@ -18,25 +18,25 @@ def value_of(value): class ToolProvider(db.Model): - __tablename__ = 'tool_providers' + __tablename__ = "tool_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_provider_pkey'), - db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), + db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) tool_name = db.Column(db.String(40), nullable=False) encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_credentials is not None @property diff --git a/api/models/tools.py b/api/models/tools.py index 069dc5bad083c8..6b69a219b15184 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -15,15 +15,16 @@ class BuiltinToolProvider(db.Model): """ This table stores the tool provider information for built-in tools for each tenant. """ - __tablename__ = 'tool_builtin_providers' + + __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), # one tenant can only have one tool provider with the same name - db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the tenant tenant_id = db.Column(StringUUID, nullable=True) # who created this tool provider @@ -32,27 +33,29 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials(self) -> dict: return json.loads(self.encrypted_credentials) + class PublishedAppTool(db.Model): """ The table stores the apps published as a tool for each person. """ - __tablename__ = 'tool_published_apps' + + __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), - db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app - app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False) + app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) # who published this tool user_id = db.Column(StringUUID, nullable=False) # description of the tool, stored in i18n format, for human @@ -67,28 +70,30 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) - + @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ApiToolProvider(db.Model): """ The table stores the api providers. """ - __tablename__ = 'tool_api_providers' + + __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider name = db.Column(db.String(40), nullable=False) # icon @@ -111,21 +116,21 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer = db.Column(db.String(255), nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] - + @property def credentials(self) -> dict: return json.loads(self.credentials_str) - + @property def user(self) -> Account: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -134,17 +139,19 @@ def user(self) -> Account: def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + class ToolLabelBinding(db.Model): """ The table stores the labels for tools. """ - __tablename__ = 'tool_label_bindings' + + __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), - db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id tool_id = db.Column(db.String(64), nullable=False) # tool type @@ -152,28 +159,30 @@ class ToolLabelBinding(db.Model): # label name label_name = db.Column(db.String(40), nullable=False) + class WorkflowToolProvider(db.Model): """ The table stores the workflow providers. """ - __tablename__ = 'tool_workflow_providers' + + __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider name = db.Column(db.String(40), nullable=False) # label of the workflow provider - label = db.Column(db.String(255), nullable=False, server_default='') + label = db.Column(db.String(255), nullable=False, server_default="") # icon icon = db.Column(db.String(255), nullable=False) # app id of the workflow provider app_id = db.Column(StringUUID, nullable=False) # version of the workflow provider - version = db.Column(db.String(255), nullable=False, server_default='') + version = db.Column(db.String(255), nullable=False, server_default="") # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -181,17 +190,17 @@ class WorkflowToolProvider(db.Model): # description of the provider description = db.Column(db.Text, nullable=False) # parameter configuration - parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') + parameter_configuration = db.Column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy = db.Column(db.String(255), nullable=True, server_default='') + privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def user(self) -> Account: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -199,28 +208,25 @@ def user(self) -> Account: @property def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() - + @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(**config) - for config in json.loads(self.parameter_configuration) - ] - + return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke """ + __tablename__ = "tool_model_invokes" - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -238,29 +244,31 @@ class ToolModelInvoke(db.Model): # invoke response model_response = db.Column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + class ToolConversationVariables(db.Model): """ store the conversation variables from tool invoke """ + __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), + db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index('user_id_idx', 'user_id'), - db.Index('conversation_id_idx', 'conversation_id'), + db.Index("user_id_idx", "user_id"), + db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -270,25 +278,27 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def variables(self) -> dict: return json.loads(self.variables_str) - + + class ToolFile(db.Model): """ store the file created by agent """ + __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_file_pkey'), + db.PrimaryKeyConstraint("id", name="tool_file_pkey"), # add index for conversation_id - db.Index('tool_file_conversation_id_idx', 'conversation_id'), + db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -300,4 +310,4 @@ class ToolFile(db.Model): # mime type mimetype = db.Column(db.String(255), nullable=False) # original url - original_url = db.Column(db.String(2048), nullable=True) \ No newline at end of file + original_url = db.Column(db.String(2048), nullable=True) diff --git a/api/models/types.py b/api/models/types.py index 1614ec20188541..cb6773e70cdd5f 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -9,13 +9,13 @@ class StringUUID(TypeDecorator): def process_bind_param(self, value, dialect): if value is None: return value - elif dialect.name == 'postgresql': + elif dialect.name == "postgresql": return str(value) else: return value.hex def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': + if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) @@ -23,4 +23,4 @@ def load_dialect_impl(self, dialect): def process_result_value(self, value, dialect): if value is None: return value - return str(value) \ No newline at end of file + return str(value) diff --git a/api/models/web.py b/api/models/web.py index 0e901d5f842691..bc088c185d5a8b 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,3 @@ - from extensions.ext_database import db from .model import Message @@ -6,18 +5,18 @@ class SavedMessage(db.Model): - __tablename__ = 'saved_messages' + __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='saved_message_pkey'), - db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="saved_message_pkey"), + db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def message(self): @@ -25,15 +24,15 @@ def message(self): class PinnedConversation(db.Model): - __tablename__ = 'pinned_conversations' + __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), - db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index e78b5666bcfcdf..d52749f0ff4e5a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,11 +22,12 @@ class CreatedByRole(Enum): """ Created By Role Enum """ - ACCOUNT = 'account' - END_USER = 'end_user' + + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. @@ -36,18 +37,19 @@ def value_of(cls, value: str) -> 'CreatedByRole': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid created by role value {value}') + raise ValueError(f"invalid created by role value {value}") class WorkflowType(Enum): """ Workflow Type Enum """ - WORKFLOW = 'workflow' - CHAT = 'chat' + + WORKFLOW = "workflow" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'WorkflowType': + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -57,10 +59,10 @@ def value_of(cls, value: str) -> 'WorkflowType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow type value {value}') + raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. @@ -68,6 +70,7 @@ def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': :return: workflow type """ from models.model import AppMode + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT @@ -105,13 +108,13 @@ class Workflow(db.Model): - updated_at (timestamp) `optional` Last update time """ - __tablename__ = 'workflows' + __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), + db.PrimaryKeyConstraint("id", name="workflow_pkey"), + db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = db.Column(StringUUID, nullable=False) type: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -119,15 +122,31 @@ class Workflow(db.Model): graph: Mapped[str] = db.Column(db.Text) features: Mapped[str] = db.Column(db.Text) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) updated_by: Mapped[str] = db.Column(StringUUID) updated_at: Mapped[datetime] = db.Column(db.DateTime) - _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + _environment_variables: Mapped[str] = db.Column( + "environment_variables", db.Text, nullable=False, server_default="{}" + ) + _conversation_variables: Mapped[str] = db.Column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) - def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, - features: str, created_by: str, environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable]): + def __init__( + self, + *, + tenant_id: str, + app_id: str, + type: str, + version: str, + graph: str, + features: str, + created_by: str, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ): self.tenant_id = tenant_id self.app_id = app_id self.type = type @@ -160,22 +179,20 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] graph_dict = self.graph_dict - if 'nodes' not in graph_dict: + if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) + start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) if not start_node: return [] # get user_input_form from start node - variables = start_node.get('data', {}).get('variables', []) + variables = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] for variable in variables: - old_structure_variables.append({ - variable['type']: variable - }) + old_structure_variables.append({variable["type"]: variable}) return old_structure_variables @@ -188,25 +205,24 @@ def unique_hash(self) -> str: :return: hash """ - entity = { - 'graph': self.graph_dict, - 'features': self.features_dict - } + entity = {"graph": self.graph_dict, "features": self.features_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property def tool_published(self) -> bool: from models.tools import WorkflowToolProvider - return db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.app_id == self.app_id - ).first() is not None + + return ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() + is not None + ) @property def environment_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: - self._environment_variables = '{}' + self._environment_variables = "{}" tenant_id = contexts.tenant_id.get() @@ -215,9 +231,7 @@ def environment_variables(self) -> Sequence[Variable]: # decrypt secret variables value decrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -230,19 +244,17 @@ def environment_variables(self, value: Sequence[Variable]): value = list(value) if any(var for var in value if not var.id): - raise ValueError('environment variable require a unique id') + raise ValueError("environment variable require a unique id") # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). origin_variables_dictionary = {var.id: var for var in self.environment_variables} for i, variable in enumerate(value): if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: - value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) + value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value encrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -256,15 +268,15 @@ def environment_variables(self, value: Sequence[Variable]): def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: environment_variables = list(self.environment_variables) environment_variables = [ - v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''}) + v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] result = { - 'graph': self.graph_dict, - 'features': self.features_dict, - 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], - 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], + "graph": self.graph_dict, + "features": self.features_dict, + "environment_variables": [var.model_dump(mode="json") for var in environment_variables], + "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], } return result @@ -272,7 +284,7 @@ def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: def conversation_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._conversation_variables is None: - self._conversation_variables = '{}' + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] @@ -290,11 +302,12 @@ class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum """ - DEBUGGING = 'debugging' - APP_RUN = 'app-run' + + DEBUGGING = "debugging" + APP_RUN = "app-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": """ Get value of given mode. @@ -304,20 +317,21 @@ def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run triggered from value {value}') + raise ValueError(f"invalid workflow run triggered from value {value}") class WorkflowRunStatus(Enum): """ Workflow Run Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' - STOPPED = 'stopped' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunStatus': + def value_of(cls, value: str) -> "WorkflowRunStatus": """ Get value of given mode. @@ -327,7 +341,7 @@ def value_of(cls, value: str) -> 'WorkflowRunStatus': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run status value {value}') + raise ValueError(f"invalid workflow run status value {value}") class WorkflowRun(db.Model): @@ -368,14 +382,14 @@ class WorkflowRun(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_runs' + __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), - db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), + db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) sequence_number = db.Column(db.Integer, nullable=False) @@ -388,26 +402,25 @@ class WorkflowRun(db.Model): status = db.Column(db.String(255), nullable=False) outputs = db.Column(db.Text) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - total_steps = db.Column(db.Integer, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) + total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): @@ -422,12 +435,12 @@ def outputs_dict(self): return json.loads(self.outputs) if self.outputs else None @property - def message(self) -> Optional['Message']: + def message(self) -> Optional["Message"]: from models.model import Message - return db.session.query(Message).filter( - Message.app_id == self.app_id, - Message.workflow_run_id == self.id - ).first() + + return ( + db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + ) @property def workflow(self): @@ -435,51 +448,51 @@ def workflow(self): def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'app_id': self.app_id, - 'sequence_number': self.sequence_number, - 'workflow_id': self.workflow_id, - 'type': self.type, - 'triggered_from': self.triggered_from, - 'version': self.version, - 'graph': self.graph_dict, - 'inputs': self.inputs_dict, - 'status': self.status, - 'outputs': self.outputs_dict, - 'error': self.error, - 'elapsed_time': self.elapsed_time, - 'total_tokens': self.total_tokens, - 'total_steps': self.total_steps, - 'created_by_role': self.created_by_role, - 'created_by': self.created_by, - 'created_at': self.created_at, - 'finished_at': self.finished_at, + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "sequence_number": self.sequence_number, + "workflow_id": self.workflow_id, + "type": self.type, + "triggered_from": self.triggered_from, + "version": self.version, + "graph": self.graph_dict, + "inputs": self.inputs_dict, + "status": self.status, + "outputs": self.outputs_dict, + "error": self.error, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "total_steps": self.total_steps, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + "finished_at": self.finished_at, } @classmethod - def from_dict(cls, data: dict) -> 'WorkflowRun': + def from_dict(cls, data: dict) -> "WorkflowRun": return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - app_id=data.get('app_id'), - sequence_number=data.get('sequence_number'), - workflow_id=data.get('workflow_id'), - type=data.get('type'), - triggered_from=data.get('triggered_from'), - version=data.get('version'), - graph=json.dumps(data.get('graph')), - inputs=json.dumps(data.get('inputs')), - status=data.get('status'), - outputs=json.dumps(data.get('outputs')), - error=data.get('error'), - elapsed_time=data.get('elapsed_time'), - total_tokens=data.get('total_tokens'), - total_steps=data.get('total_steps'), - created_by_role=data.get('created_by_role'), - created_by=data.get('created_by'), - created_at=data.get('created_at'), - finished_at=data.get('finished_at'), + id=data.get("id"), + tenant_id=data.get("tenant_id"), + app_id=data.get("app_id"), + sequence_number=data.get("sequence_number"), + workflow_id=data.get("workflow_id"), + type=data.get("type"), + triggered_from=data.get("triggered_from"), + version=data.get("version"), + graph=json.dumps(data.get("graph")), + inputs=json.dumps(data.get("inputs")), + status=data.get("status"), + outputs=json.dumps(data.get("outputs")), + error=data.get("error"), + elapsed_time=data.get("elapsed_time"), + total_tokens=data.get("total_tokens"), + total_steps=data.get("total_steps"), + created_by_role=data.get("created_by_role"), + created_by=data.get("created_by"), + created_at=data.get("created_at"), + finished_at=data.get("finished_at"), ) @@ -487,11 +500,12 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): """ Workflow Node Execution Triggered From Enum """ - SINGLE_STEP = 'single-step' - WORKFLOW_RUN = 'workflow-run' + + SINGLE_STEP = "single-step" + WORKFLOW_RUN = "workflow-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": """ Get value of given mode. @@ -501,19 +515,20 @@ def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution triggered from value {value}') + raise ValueError(f"invalid workflow node execution triggered from value {value}") class WorkflowNodeExecutionStatus(Enum): """ Workflow Node Execution Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": """ Get value of given mode. @@ -523,7 +538,7 @@ def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution status value {value}') + raise ValueError(f"invalid workflow node execution status value {value}") class WorkflowNodeExecution(db.Model): @@ -574,18 +589,31 @@ class WorkflowNodeExecution(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_node_executions' + __tablename__ = "workflow_node_executions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), - db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'workflow_run_id'), - db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_id'), - db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_execution_id'), + db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + db.Index( + "workflow_node_execution_workflow_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "workflow_run_id", + ), + db.Index( + "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" + ), + db.Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -602,9 +630,9 @@ class WorkflowNodeExecution(db.Model): outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -612,15 +640,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def inputs_dict(self): @@ -641,15 +668,17 @@ def execution_metadata_dict(self): @property def extras(self): from core.tools.tool_manager import ToolManager + extras = {} if self.execution_metadata_dict: from core.workflow.entities.node_entities import NodeType - if self.node_type == NodeType.TOOL.value and 'tool_info' in self.execution_metadata_dict: - tool_info = self.execution_metadata_dict['tool_info'] - extras['icon'] = ToolManager.get_tool_icon( + + if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + tool_info = self.execution_metadata_dict["tool_info"] + extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, - provider_type=tool_info['provider_type'], - provider_id=tool_info['provider_id'] + provider_type=tool_info["provider_type"], + provider_id=tool_info["provider_id"], ) return extras @@ -659,12 +688,13 @@ class WorkflowAppLogCreatedFrom(Enum): """ Workflow App Log Created From Enum """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - INSTALLED_APP = 'installed-app' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. @@ -674,7 +704,7 @@ def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow app log created from value {value}') + raise ValueError(f"invalid workflow app log created from value {value}") class WorkflowAppLog(db.Model): @@ -706,13 +736,13 @@ class WorkflowAppLog(db.Model): - created_at (timestamp) Creation time """ - __tablename__ = 'workflow_app_logs' + __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), - db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -720,7 +750,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def workflow_run(self): @@ -729,26 +759,27 @@ def workflow_run(self): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None class ConversationVariable(db.Model): - __tablename__ = 'workflow_conversation_variables' + __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: self.id = id @@ -757,7 +788,7 @@ def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> self.data = data @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, diff --git a/api/pyproject.toml b/api/pyproject.toml index 16b8c32c2a4a5f..dc7e271ccf4053 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -68,7 +68,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "models/**/*.py", "migrations/**/*", ] From d4dc54447abbf76d6b744e15248809178aacc712 Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:53:44 +0800 Subject: [PATCH 10/13] fix the tooltip in tool nodes (#8215) --- .../model-provider-page/model-modal/Form.tsx | 18 +++++++++--------- .../share/text-generation/result/index.tsx | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index e0b28a6d732e5a..c0a7be68a65a69 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -70,15 +70,15 @@ const Form: FC = ({ const renderField = (formSchema: CredentialFormSchema) => { const tooltip = formSchema.tooltip const tooltipContent = (tooltip && ( - - - {tooltip[language] || tooltip.en_US} -
} - triggerClassName='w-4 h-4' - /> - )) + + {tooltip[language] || tooltip.en_US} +
} + triggerClassName='ml-1 w-4 h-4' + asChild={false} + /> + )) if (formSchema.type === FormTypeEnum.textInput || formSchema.type === FormTypeEnum.secretInput || formSchema.type === FormTypeEnum.textNumber) { const { variable, diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 96fe9f01ef20ec..2d5546f9b4f647 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -219,7 +219,7 @@ const Result: FC = ({ })) }, onIterationNext: () => { - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true const iterations = draft.tracing.find(item => item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! From bb3002b1735c90dcc2ee0d32ef3e12dc72af0ed5 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:21:22 +0800 Subject: [PATCH 11/13] revert page column (#8217) --- api/core/rag/datasource/vdb/vector_factory.py | 2 +- .../nodes/knowledge_retrieval/knowledge_retrieval_node.py | 4 ---- api/core/workflow/nodes/llm/llm_node.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 7d2db140df2ad1..bb24143a41ac8d 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -27,7 +27,7 @@ def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict class Vector: def __init__(self, dataset: Dataset, attributes: list = None): if attributes is None: - attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 19deca162aed24..53e8be64153d99 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -163,9 +163,6 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: for item in all_documents: if item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - # both 'page' and 'score' are metadata fields - if item.metadata.get("page"): - page_number_list[item.metadata["doc_id"]] = item.metadata["page"] index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( @@ -200,7 +197,6 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "document_id": document.id, "document_name": document.name, "document_data_source_type": document.data_source_type, - "page": page_number_list.get(segment.index_node_id, None), "segment_id": segment.id, "retriever_from": "workflow", "score": document_score_list.get(segment.index_node_id, None), diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 6dfd27861eefc4..049c2114882830 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -451,7 +451,6 @@ def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optiona "segment_position": metadata.get("segment_position"), "index_node_hash": metadata.get("segment_index_node_hash"), "content": context_dict.get("content"), - "page": metadata.get("page"), } return source From ffd4bf8bf0cb93d9b841f1c8a0f5624ee41ca967 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 10 Sep 2024 18:47:59 +0800 Subject: [PATCH 12/13] fix(docker/docker-compose.yaml): Set default value for `REDIS_SENTINEL_SOCKET_TIMEOUT` and `CELERY_SENTINEL_SOCKET_TIMEOUT` (#8218) --- docker/docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 68f897ddb92a6b..5afb876a1cf826 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -47,12 +47,12 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} - REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-} + REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} - CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-} + CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-local} From 7e88556a7658044503af5499a06aa321c58da03a Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Sep 2024 19:09:04 +0800 Subject: [PATCH 13/13] fix(workflow): parallel run error in multi condition branches --- .../workflow/graph_engine/entities/graph.py | 269 +++++++++++------- 1 file changed, 167 insertions(+), 102 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index b54595f7802d9b..298482de1b48e2 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -304,11 +304,14 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = [] + parallel_branch_node_ids = {} condition_edge_mappings = {} for graph_edge in target_node_edges: if graph_edge.run_condition is None: - parallel_branch_node_ids.append(graph_edge.target_node_id) + if "default" not in parallel_branch_node_ids: + parallel_branch_node_ids["default"] = [] + + parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash if not condition_hash in condition_edge_mappings: @@ -316,120 +319,182 @@ def _recursively_add_parallels( condition_edge_mappings[condition_hash].append(graph_edge) - for _, graph_edges in condition_edge_mappings.items(): + for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: + if condition_hash not in parallel_branch_node_ids: + parallel_branch_node_ids[condition_hash] = [] + for graph_edge in graph_edges: - parallel_branch_node_ids.append(graph_edge.target_node_id) + parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) + + condition_parallels = {} + for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items(): + # any target node id in node_parallel_mapping + parallel = None + if condition_parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + ) + parallel_mapping[parallel.id] = parallel + condition_parallels[condition_hash] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=condition_parallel_branch_node_ids, + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue - # any target node id in node_parallel_mapping - if parallel_branch_node_ids: - parent_parallel_id = parent_parallel.id if parent_parallel else None + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue - parallel = GraphParallel( - start_from_node_id=start_node_id, - parent_parallel_id=parent_parallel.id if parent_parallel else None, - parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + if condition_edge_mappings: + for condition_hash, graph_edges in condition_edge_mappings.items(): + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, ) - parallel_mapping[parallel.id] = parallel - in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + cls._recursively_add_parallels( edge_mapping=edge_mapping, reverse_edge_mapping=reverse_edge_mapping, - parallel_branch_node_ids=parallel_branch_node_ids, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, ) - # collect all branches node ids - parallel_node_ids = [] - for _, node_ids in in_branch_node_ids.items(): - for node_id in node_ids: - in_parent_parallel = True - if parent_parallel_id: - in_parent_parallel = False - for parallel_node_id, parallel_id in node_parallel_mapping.items(): - if parallel_id == parent_parallel_id and parallel_node_id == node_id: - in_parent_parallel = True - break - - if in_parent_parallel: - parallel_node_ids.append(node_id) - node_parallel_mapping[node_id] = parallel.id - - outside_parallel_target_node_ids = set() - for node_id in parallel_node_ids: - if node_id == parallel.start_from_node_id: - continue - - node_edges = edge_mapping.get(node_id) - if not node_edges: - continue - - if len(node_edges) > 1: - continue - - target_node_id = node_edges[0].target_node_id - if target_node_id in parallel_node_ids: - continue - - if parent_parallel_id: - parent_parallel = parallel_mapping.get(parent_parallel_id) - if not parent_parallel: - continue - - if ( - ( - node_parallel_mapping.get(target_node_id) - and node_parallel_mapping.get(target_node_id) == parent_parallel_id - ) + @classmethod + def _get_current_parallel( + cls, + parallel_mapping: dict[str, GraphParallel], + graph_edge: GraphEdge, + parallel: Optional[GraphParallel] = None, + parent_parallel: Optional[GraphParallel] = None, + ) -> Optional[GraphParallel]: + """ + Get current parallel + """ + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id or ( - parent_parallel - and parent_parallel.end_to_node_id - and target_node_id == parent_parallel.end_to_node_id + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id ) - or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) ): - outside_parallel_target_node_ids.add(target_node_id) + current_parallel = parent_parallel_parent_parallel - if len(outside_parallel_target_node_ids) == 1: - if ( - parent_parallel - and parent_parallel.end_to_node_id - and parallel.end_to_node_id == parent_parallel.end_to_node_id - ): - parallel.end_to_node_id = None - else: - parallel.end_to_node_id = outside_parallel_target_node_ids.pop() - - for graph_edge in target_node_edges: - current_parallel = None - if parallel: - current_parallel = parallel - elif parent_parallel: - if not parent_parallel.end_to_node_id or ( - parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id - ): - current_parallel = parent_parallel - else: - # fetch parent parallel's parent parallel - parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id - if parent_parallel_parent_parallel_id: - parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) - if parent_parallel_parent_parallel and ( - not parent_parallel_parent_parallel.end_to_node_id - or ( - parent_parallel_parent_parallel.end_to_node_id - and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id - ) - ): - current_parallel = parent_parallel_parent_parallel - - cls._recursively_add_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - start_node_id=graph_edge.target_node_id, - parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel, - ) + return current_parallel @classmethod def _check_exceed_parallel_limit(