Skip to content

Commit

Permalink
refactor(workflow): introduce specific error handling for LLM nodes
Browse files Browse the repository at this point in the history
- Define custom exception classes for various LLM node errors.
- Replace generic ValueErrors with specific exceptions to improve error handling.
- Enhance clarity and maintainability by categorizing errors, aiding debugging and troubleshooting.
  • Loading branch information
laipz8200 committed Nov 3, 2024
1 parent 2ed6bb8 commit 04b7a5b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
26 changes: 26 additions & 0 deletions api/core/workflow/nodes/llm/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
class LLMNodeError(ValueError):
    """Root of the LLM node exception hierarchy.

    Deliberately derives from ``ValueError`` so that any existing caller
    which still catches ``ValueError`` keeps working after the migration
    from generic errors to these specific exception types.
    """


class VariableNotFoundError(LLMNodeError):
    """Signals that a variable the node requires could not be resolved
    from the variable pool."""


class InvalidContextStructureError(LLMNodeError):
    """Signals that a retrieved context item does not have the expected
    structure (e.g. a required field is missing)."""


class InvalidVariableTypeError(LLMNodeError):
    """Signals that a resolved variable (or template) has a type the node
    cannot handle."""


class ModelNotExistError(LLMNodeError):
    """Signals that the configured model name could not be found for the
    provider."""


class LLMModeRequiredError(LLMNodeError):
    """Signals that the model configuration is missing the mandatory LLM
    mode setting."""


class NoPromptFoundError(LLMNodeError):
    """Signals that the LLM configuration yielded no usable prompt
    messages."""
33 changes: 21 additions & 12 deletions api/core/workflow/nodes/llm/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
ModelNotExistError,
NoPromptFoundError,
VariableNotFoundError,
)

if TYPE_CHECKING:
from core.file.models import File
Expand Down Expand Up @@ -115,7 +124,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise ValueError("Query not found")
raise VariableNotFoundError("Query not found")
query = query.text
else:
query = None
Expand Down Expand Up @@ -161,7 +170,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
usage = event.usage
finish_reason = event.finish_reason
break
except Exception as e:
except LLMNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
Expand Down Expand Up @@ -275,7 +284,7 @@ def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")

def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
Expand Down Expand Up @@ -325,7 +334,7 @@ def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
for variable_selector in variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
inputs[variable_selector.variable] = ""
inputs[variable_selector.variable] = variable.to_object()
Expand All @@ -338,7 +347,7 @@ def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
for variable_selector in query_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
continue
inputs[variable_selector.variable] = variable.to_object()
Expand All @@ -355,7 +364,7 @@ def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")

def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
Expand All @@ -376,7 +385,7 @@ def _fetch_context(self, node_data: LLMNodeData):
context_str += item + "\n"
else:
if "content" not in item:
raise ValueError(f"Invalid context structure: {item}")
raise InvalidContextStructureError(f"Invalid context structure: {item}")

context_str += item["content"] + "\n"

Expand Down Expand Up @@ -441,7 +450,7 @@ def _fetch_model_config(
)

if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
Expand All @@ -460,12 +469,12 @@ def _fetch_model_config(
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise LLMModeRequiredError("LLM mode is required.")

model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")

return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
Expand Down Expand Up @@ -564,7 +573,7 @@ def _fetch_prompt_messages(
filtered_prompt_messages.append(prompt_message)

if not filtered_prompt_messages:
raise ValueError(
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
Expand Down Expand Up @@ -636,7 +645,7 @@ def _extract_variable_selector_to_variable_mapping(
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")

variable_mapping = {}
for variable_selector in variable_selectors:
Expand Down

0 comments on commit 04b7a5b

Please sign in to comment.