diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 691d178ba2aefa..c6c4923ee684f2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -68,7 +68,6 @@ from models.workflow import ( Workflow, WorkflowNodeExecution, - WorkflowRun, WorkflowRunStatus, ) @@ -104,10 +103,12 @@ def __init__( ) if isinstance(user, EndUser): - self._user_id = user.session_id + self._user_id = user.id + user_session_id = user.session_id self._created_by_role = CreatedByRole.END_USER elif isinstance(user, Account): self._user_id = user.id + user_session_id = user.id self._created_by_role = CreatedByRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") @@ -125,7 +126,7 @@ def __init__( SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: self._user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, @@ -137,6 +138,7 @@ def __init__( self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] + self._workflow_run_id = "" def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -266,7 +268,6 @@ def _process_stream_response( """ # init fake graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None - workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -291,111 +292,163 @@ def _process_stream_response( user_id=self._user_id, created_by_role=self._created_by_role, ) + self._workflow_run_id = workflow_run.id message = self._get_message(session=session) if not message: raise ValueError(f"Message not found: {self._message_id}") message.workflow_run_id = workflow_run.id - session.commit() - workflow_start_resp = self._workflow_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + session.commit() + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - node_retry_resp = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + node_retry_resp = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_retry_resp: yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) - node_start_resp = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + node_start_resp = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_start_resp: yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - node_finish_resp = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) + + node_finish_resp = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_finish_resp: yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) + + node_finish_resp = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - node_finish_resp = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) if node_finish_resp: yield node_finish_resp - elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") if not graph_runtime_state: @@ -404,7 +457,7 @@ def _process_stream_response( with Session(db.engine) as session: workflow_run = self._handle_workflow_run_success( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -421,16 +474,15 @@ def _process_stream_response( yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") with Session(db.engine) as session: workflow_run = self._handle_workflow_run_partial_success( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -439,7 +491,6 @@ def _process_stream_response( conversation_id=None, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) @@ -448,16 +499,15 @@ def _process_stream_response( yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") with Session(db.engine) as session: workflow_run = self._handle_workflow_run_failed( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -473,15 +523,16 @@ def _process_stream_response( err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) err = self._handle_error(event=err_event, session=session, message_id=self._message_id) session.commit() + yield workflow_finish_resp yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: + if self._workflow_run_id and graph_runtime_state: with Session(db.engine) as session: workflow_run = self._handle_workflow_run_failed( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -490,7 +541,6 @@ def _process_stream_response( conversation_id=self._conversation_id, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, @@ -499,6 +549,7 @@ def _process_stream_response( # Save message self._save_message(session=session, graph_runtime_state=graph_runtime_state) session.commit() + yield workflow_finish_resp yield self._message_end_to_stream_response() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 574596d4f5a77c..c447f9c2fc1515 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -91,10 +91,12 @@ def __init__( ) if isinstance(user, EndUser): - self._user_id = user.session_id + self._user_id = user.id + user_session_id = user.session_id self._created_by_role = CreatedByRole.END_USER elif isinstance(user, Account): self._user_id = user.id + user_session_id = user.id self._created_by_role = CreatedByRole.ACCOUNT else: raise ValueError(f"Invalid user type: {type(user)}") @@ -104,7 +106,7 @@ def __init__( self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: self._user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, @@ -112,6 +114,7 @@ def __init__( self._task_state = WorkflowTaskState() self._wip_workflow_node_executions = {} + self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -233,7 +236,6 @@ def _process_stream_response( :return: """ graph_runtime_state = None - workflow_run = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -256,111 +258,168 @@ def _process_stream_response( user_id=self._user_id, created_by_role=self._created_by_role, ) + self._workflow_run_id = workflow_run.id start_resp = self._workflow_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() + yield start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + response = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if response: yield response elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - - node_start_response = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) + node_start_response = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - - node_success_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) + node_success_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_success_response: yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_failed( + session=session, + event=event, + ) + node_failed_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - node_failed_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) if node_failed_response: yield node_failed_response elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") with Session(db.engine) as session: workflow_run = self._handle_workflow_run_success( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -378,18 +437,18 @@ def _process_stream_response( workflow_run=workflow_run, ) session.commit() + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") with Session(db.engine) as session: workflow_run = self._handle_workflow_run_partial_success( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -409,15 +468,15 @@ def _process_stream_response( yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") + with Session(db.engine) as session: workflow_run = self._handle_workflow_run_failed( session=session, - workflow_run=workflow_run, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, @@ -437,6 +496,7 @@ def _process_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 2692008c6653d6..7215105a5652da 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -46,7 +46,6 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -66,7 +65,6 @@ class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def _handle_workflow_run_start( self, @@ -130,7 +128,7 @@ def _handle_workflow_run_success( self, *, session: Session, - workflow_run: WorkflowRun, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -148,7 +146,7 @@ def _handle_workflow_run_success( :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -175,7 +173,7 @@ def _handle_workflow_run_partial_success( self, *, session: Session, - workflow_run: WorkflowRun, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -184,18 +182,7 @@ def _handle_workflow_run_partial_success( conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) - + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value @@ -222,7 +209,7 @@ def _handle_workflow_run_failed( self, *, session: Session, - workflow_run: WorkflowRun, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -242,7 +229,7 @@ def _handle_workflow_run_failed( :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) workflow_run.status = status.value workflow_run.error = error @@ -284,49 +271,41 @@ def _handle_workflow_run_failed( return workflow_run def _handle_node_execution_start( - self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: - # init workflow node execution - - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - } - ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_node_execution) - session.commit() - session.refresh(workflow_node_execution) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + session.add(workflow_node_execution) return workflow_node_execution - def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param event: queue node succeeded event - :return: - """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) - + def _handle_workflow_node_execution_success( + self, *, session: Session, event: QueueNodeSucceededEvent + ) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -336,20 +315,6 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.execution_metadata: execution_metadata, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value @@ -360,19 +325,22 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent + self, + *, + session: Session, + event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent, ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -382,25 +350,6 @@ def _handle_workflow_node_execution_failed( execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: ( - WorkflowNodeExecutionStatus.FAILED.value - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value - ), - WorkflowNodeExecution.error: event.error, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - WorkflowNodeExecution.execution_metadata: execution_metadata, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = ( WorkflowNodeExecutionStatus.FAILED.value @@ -415,12 +364,10 @@ def _handle_workflow_node_execution_failed( workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - return workflow_node_execution def _handle_workflow_node_execution_retried( - self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -444,6 +391,7 @@ def _handle_workflow_node_execution_retried( execution_metadata = json.dumps(merged_metadata) workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.workflow_id = workflow_run.workflow_id @@ -466,10 +414,7 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - + session.add(workflow_node_execution) return workflow_node_execution ################################################# @@ -547,17 +492,20 @@ def _workflow_finish_to_stream_response( ) def _workflow_node_start_to_stream_response( - self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + self, + *, + session: Session, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - """ - Workflow node start to stream response. - :param event: queue node started event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None response = NodeStartStreamResponse( task_id=task_id, @@ -593,6 +541,8 @@ def _workflow_node_start_to_stream_response( def _workflow_node_finish_to_stream_response( self, + *, + session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -600,15 +550,14 @@ def _workflow_node_finish_to_stream_response( task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - """ - Workflow node finish to stream response. - :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeFinishStreamResponse( task_id=task_id, @@ -640,19 +589,20 @@ def _workflow_node_finish_to_stream_response( def _workflow_node_retry_to_stream_response( self, + *, + session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - """ - Workflow node finish to stream response. - :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeRetryStreamResponse( task_id=task_id, @@ -684,15 +634,10 @@ def _workflow_node_retry_to_stream_response( ) def _workflow_parallel_branch_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - """ - Workflow parallel branch start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run started event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -708,17 +653,14 @@ def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_finished_to_stream_response( self, + *, + session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - """ - Workflow parallel branch finished to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run succeeded or failed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -735,15 +677,10 @@ def _workflow_parallel_branch_finished_to_stream_response( ) def _workflow_iteration_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - """ - Workflow iteration start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration start event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -762,15 +699,10 @@ def _workflow_iteration_start_to_stream_response( ) def _workflow_iteration_next_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - """ - Workflow iteration next to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration next event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeNextStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -791,15 +723,10 @@ def _workflow_iteration_next_to_stream_response( ) def _workflow_iteration_completed_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - """ - Workflow iteration completed to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration completed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -883,7 +810,7 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any return None - def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: + def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run :param workflow_run_id: workflow run id @@ -896,14 +823,9 @@ def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> Wo return workflow_run - def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - """ - Refetch workflow node execution - :param node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) - + def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id) + workflow_node_execution = session.scalar(stmt) if not workflow_node_execution: raise WorkflowNodeExecutionNotFoundError(node_execution_id) diff --git a/api/models/workflow.py b/api/models/workflow.py index 78a7f8169fe634..32a0860b77bbea 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -400,11 +400,11 @@ class WorkflowRun(db.Model): # type: ignore[name-defined] type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) version: Mapped[str] = mapped_column(db.String(255)) - graph: Mapped[str] = mapped_column(db.Text) - inputs: Mapped[str] = mapped_column(db.Text) + graph: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[str] = mapped_column(db.Text) + error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) @@ -609,29 +609,29 @@ class WorkflowNodeExecution(db.Model): # type: ignore[name-defined] ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - workflow_run_id = db.Column(StringUUID) - index = db.Column(db.Integer, nullable=False) - predecessor_node_id = db.Column(db.String(255)) - node_execution_id = db.Column(db.String(255), nullable=True) - node_id = db.Column(db.String(255), nullable=False) - node_type = db.Column(db.String(255), nullable=False) - title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text) - process_data = db.Column(db.Text) - outputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - finished_at = db.Column(db.DateTime) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + workflow_id: Mapped[str] = mapped_column(StringUUID) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + index: Mapped[int] = mapped_column(db.Integer) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_id: Mapped[str] = mapped_column(db.String(255)) + node_type: Mapped[str] = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(db.String(255)) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + process_data: Mapped[Optional[str]] = mapped_column(db.Text) + outputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) + error: Mapped[Optional[str]] = mapped_column(db.Text) + elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_by: Mapped[str] = mapped_column(StringUUID) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) @property def created_by_account(self): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ea8192edde35cc..81b197a2478992 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from datetime import UTC, datetime from typing import Any, Optional, cast +from uuid import uuid4 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -277,6 +278,7 @@ def run_draft_workflow_node( error = e.error workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) workflow_node_execution.tenant_id = app_model.tenant_id workflow_node_execution.app_id = app_model.id workflow_node_execution.workflow_id = draft_workflow.id