diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py
new file mode 100644
index 00000000000000..f858be25156951
--- /dev/null
+++ b/api/core/workflow/nodes/llm/exc.py
@@ -0,0 +1,26 @@
+class LLMNodeError(ValueError):
+    """Base class for LLM Node errors."""
+
+
+class VariableNotFoundError(LLMNodeError):
+    """Raised when a required variable is not found."""
+
+
+class InvalidContextStructureError(LLMNodeError):
+    """Raised when the context structure is invalid."""
+
+
+class InvalidVariableTypeError(LLMNodeError):
+    """Raised when the variable type is invalid."""
+
+
+class ModelNotExistError(LLMNodeError):
+    """Raised when the specified model does not exist."""
+
+
+class LLMModeRequiredError(LLMNodeError):
+    """Raised when LLM mode is required but not provided."""
+
+
+class NoPromptFoundError(LLMNodeError):
+    """Raised when no prompt is found in the LLM configuration."""
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index b4728e6abf6800..34f44eab067943 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -56,6 +56,15 @@
     LLMNodeData,
     ModelConfig,
 )
+from .exc import (
+    InvalidContextStructureError,
+    InvalidVariableTypeError,
+    LLMModeRequiredError,
+    LLMNodeError,
+    ModelNotExistError,
+    NoPromptFoundError,
+    VariableNotFoundError,
+)
 
 if TYPE_CHECKING:
     from core.file.models import File
@@ -115,7 +124,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
         if self.node_data.memory:
             query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
             if not query:
-                raise ValueError("Query not found")
+                raise VariableNotFoundError("Query not found")
             query = query.text
         else:
             query = None
@@ -161,7 +170,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
                     usage = event.usage
                     finish_reason = event.finish_reason
                     break
-        except Exception as e:
+        except LLMNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
@@ -275,7 +284,7 @@ def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
             variable_name = variable_selector.variable
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
 
             def parse_dict(input_dict: Mapping[str, Any]) -> str:
                 """
@@ -325,7 +334,7 @@ def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
         for variable_selector in variable_selectors:
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
             if isinstance(variable, NoneSegment):
                 inputs[variable_selector.variable] = ""
             inputs[variable_selector.variable] = variable.to_object()
@@ -338,7 +347,7 @@ def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
             for variable_selector in query_variable_selectors:
                 variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
                 if variable is None:
-                    raise ValueError(f"Variable {variable_selector.variable} not found")
+                    raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
                 if isinstance(variable, NoneSegment):
                     continue
                 inputs[variable_selector.variable] = variable.to_object()
@@ -355,7 +364,7 @@ def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
            return variable.value
        elif isinstance(variable, NoneSegment | ArrayAnySegment):
            return []
-        raise ValueError(f"Invalid variable type: {type(variable)}")
+        raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
 
     def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
@@ -376,7 +385,7 @@ def _fetch_context(self, node_data: LLMNodeData):
                         context_str += item + "\n"
                     else:
                         if "content" not in item:
-                            raise ValueError(f"Invalid context structure: {item}")
+                            raise InvalidContextStructureError(f"Invalid context structure: {item}")
 
                         context_str += item["content"] + "\n"
 
@@ -441,7 +450,7 @@ def _fetch_model_config(
         )
 
         if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} does not exist.")
 
         if provider_model.status == ModelStatus.NO_CONFIGURE:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
@@ -460,12 +469,12 @@ def _fetch_model_config(
         # get model mode
         model_mode = node_data_model.mode
         if not model_mode:
-            raise ValueError("LLM mode is required.")
+            raise LLMModeRequiredError("LLM mode is required.")
 
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
 
         if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} does not exist.")
 
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
@@ -564,7 +573,7 @@ def _fetch_prompt_messages(
             filtered_prompt_messages.append(prompt_message)
 
         if not filtered_prompt_messages:
-            raise ValueError(
+            raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
@@ -636,7 +645,7 @@ def _extract_variable_selector_to_variable_mapping(
             variable_template_parser = VariableTemplateParser(template=prompt_template.text)
             variable_selectors = variable_template_parser.extract_variable_selectors()
         else:
-            raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
+            raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
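
Note on the design, with a minimal sketch (not part of the diff): every new exception subclasses LLMNodeError, so the single `except LLMNodeError` clause in `_run` covers all node-specific failures, while unrelated exceptions now propagate instead of being swallowed by the old `except Exception`. The class definitions below mirror exc.py; the helper `fetch_query` is a hypothetical stand-in for the variable-pool lookup in `_run`.

    # Sketch of the hierarchy added in exc.py.
    class LLMNodeError(ValueError):
        """Base class for LLM Node errors."""

    class VariableNotFoundError(LLMNodeError):
        """Raised when a required variable is not found."""

    def fetch_query() -> str:
        # Hypothetical stand-in for the variable-pool lookup in _run.
        raise VariableNotFoundError("Query not found")

    try:
        fetch_query()
    except LLMNodeError as e:
        # One clause covers every subclass; unrelated errors (e.g. KeyError)
        # now propagate instead of being silently converted to a failed run.
        print(f"LLM node failed: {e}")

    # Because LLMNodeError itself subclasses ValueError, pre-existing callers
    # that catch ValueError keep working unchanged.
    assert issubclass(VariableNotFoundError, ValueError)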