From 5bda3a384a33b8c539d04f879804e887721fe0b6 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 2 Sep 2024 17:49:51 +0800 Subject: [PATCH] fix(workflow): bugs --- .../workflow/graph_engine/entities/graph.py | 75 ++++++++++++------- .../workflow/graph_engine/graph_engine.py | 8 +- .../core/workflow/graph_engine/test_graph.py | 5 ++ 3 files changed, 60 insertions(+), 28 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 48ae33f29a18c4..5d925e5d8c650f 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -309,12 +309,15 @@ def _check_connected_to_previous_node( ) @classmethod - def _recursively_add_parallels(cls, - edge_mapping: dict[str, list[GraphEdge]], - reverse_edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - parallel_mapping: dict[str, GraphParallel], - node_parallel_mapping: dict[str, str]) -> None: + def _recursively_add_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str], + parent_parallel: Optional[GraphParallel] = None + ) -> None: """ Recursively add parallel ids @@ -322,8 +325,10 @@ def _recursively_add_parallels(cls, :param start_node_id: start from node id :param parallel_mapping: parallel mapping :param node_parallel_mapping: node parallel mapping + :param parent_parallel: parent parallel """ target_node_edges = edge_mapping.get(start_node_id, []) + parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels parallel_branch_node_ids = [] @@ -345,18 +350,7 @@ def _recursively_add_parallels(cls, # any target node id in node_parallel_mapping if parallel_branch_node_ids: - # all parallel_branch_node_ids in node_parallel_mapping - parent_parallel_id = None - for node_id in parallel_branch_node_ids: - if node_id in node_parallel_mapping: - parent_parallel_id = node_parallel_mapping[node_id] - break - - parent_parallel = None - if parent_parallel_id: - parent_parallel = parallel_mapping.get(parent_parallel_id) - if not parent_parallel: - raise Exception(f"Parent parallel {parent_parallel_id} not found") + parent_parallel_id = parent_parallel.id if parent_parallel else None parallel = GraphParallel( start_from_node_id=start_node_id, @@ -375,8 +369,17 @@ def _recursively_add_parallels(cls, parallel_node_ids = [] for _, node_ids in in_branch_node_ids.items(): for node_id in node_ids: - parallel_node_ids.append(node_id) - node_parallel_mapping[node_id] = parallel.id + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id outside_parallel_target_node_ids = set() for node_id in parallel_node_ids: @@ -418,7 +421,8 @@ def _recursively_add_parallels(cls, reverse_edge_mapping=reverse_edge_mapping, start_node_id=graph_edge.target_node_id, parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, + parent_parallel=parallel if parallel else parent_parallel ) @classmethod @@ -538,14 +542,14 @@ def _fetch_all_node_ids_in_parallels(cls, edge_mapping=edge_mapping ): if node_id in merge_branch_node_ids: - del merge_branch_node_ids[node_id] + del merge_branch_node_ids[node_id2] elif cls._is_node2_after_node1( node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping ): if node_id2 in merge_branch_node_ids: - del merge_branch_node_ids[node_id2] + del merge_branch_node_ids[node_id] branches_merge_node_ids: dict[str, str] = {} for node_id, branch_node_ids in merge_branch_node_ids.items(): @@ -613,13 +617,30 @@ def _is_node_in_routes(cls, if start_node_id not in reverse_edge_mapping: return False - all_routes_node_ids = [] - for _, node_ids in routes_node_ids.items(): + all_routes_node_ids = set() + parallel_start_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): for node_id in node_ids: - all_routes_node_ids.append(node_id) + all_routes_node_ids.add(node_id) + + if branch_node_id in reverse_edge_mapping: + for graph_edge in reverse_edge_mapping[branch_node_id]: + if graph_edge.source_node_id not in parallel_start_node_ids: + parallel_start_node_ids[graph_edge.source_node_id] = [] + + parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) + + parallel_start_node_id = None + for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + if set(branch_node_ids) == set(routes_node_ids.keys()): + parallel_start_node_id = p_start_node_id + return True + + if not parallel_start_node_id: + raise Exception("Parallel start node id not found") for graph_edge in reverse_edge_mapping[start_node_id]: - if graph_edge.source_node_id not in all_routes_node_ids: + if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: return False return True diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f0538e26710200..b3b64722c5c038 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -352,7 +352,13 @@ def _run_parallel_branches( # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) if not parallel_id: - raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.') + node_id = edge_mappings[0].target_node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f'Node {node_id} related parallel not found.') + + node_title = node_config.get('data', {}).get('title') + raise GraphRunFailedError(f'Node {node_title} related parallel not found.') parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py index a402dc0845c79d..65757cd604cfa8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -700,6 +700,11 @@ def test_parallels_graph6(): "source": "code3", "target": "answer", }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, ], "nodes": [ {"data": {"type": "start"}, "id": "start"},