From 85d319719cec904cb89cb0ea2dee19f91ae98bfd Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 24 Aug 2024 17:17:18 +0800 Subject: [PATCH] fix end node bug --- api/core/app/apps/workflow_app_runner.py | 3 +- api/core/app/entities/queue_entities.py | 2 + .../nodes/answer/answer_stream_processor.py | 2 +- .../nodes/end/end_stream_generate_router.py | 11 +- .../nodes/end/end_stream_processor.py | 117 ++++++++++++++++-- 5 files changed, 120 insertions(+), 15 deletions(-) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index a07211f5817ccf..ec85412c1e62dd 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -255,7 +255,8 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) elif isinstance(event, NodeRunStreamChunkEvent): self._publish_event( QueueTextChunkEvent( - text=event.chunk_content + text=event.chunk_content, + from_variable_selector=event.from_variable_selector ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 04226636d10005..e4d2ab44d5b5a3 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -150,6 +150,8 @@ class QueueTextChunkEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.TEXT_CHUNK text: str + from_variable_selector: Optional[list[str]] = None + """from variable selector""" class QueueAgentMessageEvent(AppQueueEvent): diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 40363df0f5632f..c2a5dd5163819d 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -203,7 +203,7 @@ def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]: return files @classmethod - def _get_file_var_from_value(self, value: dict | list) -> Optional[dict]: + def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: """ Get file var from value :param value: variable value diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 77d1d5efb08f07..8390f6d81b4d31 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -62,14 +62,13 @@ def extract_stream_variable_selector_from_node_data(cls, if node_id != 'sys' and node_id in node_id_config_mapping: node = node_id_config_mapping[node_id] node_type = node.get('data', {}).get('type') - if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': + if ( + variable_selector.value_selector not in value_selectors + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == 'text' + ): value_selectors.append(variable_selector.value_selector) - # remove duplicates - value_selector_tuples = [tuple(item) for item in value_selectors] - unique_value_selector_tuples = list(set(value_selector_tuples)) - value_selectors = [list(item) for item in unique_value_selector_tuples] - return value_selectors @classmethod diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 4232c355882260..4474c2a78a2740 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -18,9 +18,13 @@ class EndStreamProcessor(StreamProcessor): def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) - self.stream_param = graph.end_stream_param - self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy() + self.end_stream_param = graph.end_stream_param + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + self.has_outputed = False + self.outputed_node_ids = set() def process(self, generator: Generator[GraphEngineEvent, None, None] @@ -32,6 +36,15 @@ def process(self, yield event elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + continue + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ event.route_node_state.node_id @@ -42,23 +55,97 @@ def process(self, event.route_node_state.node_id ] = stream_out_end_node_ids - for _ in stream_out_end_node_ids: + if stream_out_end_node_ids: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True yield event elif isinstance(event, NodeRunSucceededEvent): yield event if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[end_node_id] += 1 + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] # remove unreachable nodes self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) else: yield event def reset(self) -> None: - self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy() + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} + def _generate_stream_outputs_when_node_finished(self, + event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for end_node_id, position in self.route_position.items(): + # all depends on end node id not in rest node ids + if (event.route_node_state.node_id != end_node_id + and (end_node_id not in self.rest_node_ids + or not all(dep_id not in self.rest_node_ids + for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + continue + + route_position = self.route_position[end_node_id] + + position = 0 + value_selectors = [] + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position >= route_position: + value_selectors.append(current_value_selectors) + + position += 1 + + for value_selector in value_selectors: + if not value_selector: + continue + + value = self.variable_pool.get( + value_selector + ) + + if value is None: + break + + text = value.markdown + + if text: + current_node_id = value_selector[0] + if self.has_outputed and current_node_id not in self.outputed_node_ids: + text = '\n' + text + + self.outputed_node_ids.add(current_node_id) + self.has_outputed = True + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[end_node_id] += 1 + def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: """ Is stream out support @@ -73,14 +160,30 @@ def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[s return [] stream_out_end_node_ids = [] - for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items(): + for end_node_id, route_position in self.route_position.items(): if end_node_id not in self.rest_node_ids: continue # all depends on end node id not in rest node ids if all(dep_id not in self.rest_node_ids - for dep_id in self.stream_param.end_dependencies[end_node_id]): - if stream_output_value_selector not in variable_selectors: + for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): + continue + + position = 0 + value_selector = None + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position == route_position: + value_selector = current_value_selectors + break + + position += 1 + + if not value_selector: + continue + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: continue stream_out_end_node_ids.append(end_node_id)