From f9e8ea1ce41a2998fabbd8039213e8761012f7f0 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Sep 2024 21:09:18 +0800 Subject: [PATCH] fix(workflow): in multi-parallel execution with multiple conditional branches (#8221) --- .../workflow/graph_engine/entities/graph.py | 269 +++++++++++------- 1 file changed, 167 insertions(+), 102 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index b54595f7802d9b..298482de1b48e2 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -304,11 +304,14 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = [] + parallel_branch_node_ids = {} condition_edge_mappings = {} for graph_edge in target_node_edges: if graph_edge.run_condition is None: - parallel_branch_node_ids.append(graph_edge.target_node_id) + if "default" not in parallel_branch_node_ids: + parallel_branch_node_ids["default"] = [] + + parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash if not condition_hash in condition_edge_mappings: @@ -316,120 +319,182 @@ def _recursively_add_parallels( condition_edge_mappings[condition_hash].append(graph_edge) - for _, graph_edges in condition_edge_mappings.items(): + for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: + if condition_hash not in parallel_branch_node_ids: + parallel_branch_node_ids[condition_hash] = [] + for graph_edge in graph_edges: - parallel_branch_node_ids.append(graph_edge.target_node_id) + parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) + + condition_parallels = {} + for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items(): + # any target node id in node_parallel_mapping + parallel = None + if condition_parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + ) + parallel_mapping[parallel.id] = parallel + condition_parallels[condition_hash] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=condition_parallel_branch_node_ids, + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue - # any target node id in node_parallel_mapping - if parallel_branch_node_ids: - parent_parallel_id = parent_parallel.id if parent_parallel else None + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue - parallel = GraphParallel( - start_from_node_id=start_node_id, - parent_parallel_id=parent_parallel.id if parent_parallel else None, - parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + if condition_edge_mappings: + for condition_hash, graph_edges in condition_edge_mappings.items(): + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, ) - parallel_mapping[parallel.id] = parallel - in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + cls._recursively_add_parallels( edge_mapping=edge_mapping, reverse_edge_mapping=reverse_edge_mapping, - parallel_branch_node_ids=parallel_branch_node_ids, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, ) - # collect all branches node ids - parallel_node_ids = [] - for _, node_ids in in_branch_node_ids.items(): - for node_id in node_ids: - in_parent_parallel = True - if parent_parallel_id: - in_parent_parallel = False - for parallel_node_id, parallel_id in node_parallel_mapping.items(): - if parallel_id == parent_parallel_id and parallel_node_id == node_id: - in_parent_parallel = True - break - - if in_parent_parallel: - parallel_node_ids.append(node_id) - node_parallel_mapping[node_id] = parallel.id - - outside_parallel_target_node_ids = set() - for node_id in parallel_node_ids: - if node_id == parallel.start_from_node_id: - continue - - node_edges = edge_mapping.get(node_id) - if not node_edges: - continue - - if len(node_edges) > 1: - continue - - target_node_id = node_edges[0].target_node_id - if target_node_id in parallel_node_ids: - continue - - if parent_parallel_id: - parent_parallel = parallel_mapping.get(parent_parallel_id) - if not parent_parallel: - continue - - if ( - ( - node_parallel_mapping.get(target_node_id) - and node_parallel_mapping.get(target_node_id) == parent_parallel_id - ) + @classmethod + def _get_current_parallel( + cls, + parallel_mapping: dict[str, GraphParallel], + graph_edge: GraphEdge, + parallel: Optional[GraphParallel] = None, + parent_parallel: Optional[GraphParallel] = None, + ) -> Optional[GraphParallel]: + """ + Get current parallel + """ + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id or ( - parent_parallel - and parent_parallel.end_to_node_id - and target_node_id == parent_parallel.end_to_node_id + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id ) - or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) ): - outside_parallel_target_node_ids.add(target_node_id) + current_parallel = parent_parallel_parent_parallel - if len(outside_parallel_target_node_ids) == 1: - if ( - parent_parallel - and parent_parallel.end_to_node_id - and parallel.end_to_node_id == parent_parallel.end_to_node_id - ): - parallel.end_to_node_id = None - else: - parallel.end_to_node_id = outside_parallel_target_node_ids.pop() - - for graph_edge in target_node_edges: - current_parallel = None - if parallel: - current_parallel = parallel - elif parent_parallel: - if not parent_parallel.end_to_node_id or ( - parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id - ): - current_parallel = parent_parallel - else: - # fetch parent parallel's parent parallel - parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id - if parent_parallel_parent_parallel_id: - parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) - if parent_parallel_parent_parallel and ( - not parent_parallel_parent_parallel.end_to_node_id - or ( - parent_parallel_parent_parallel.end_to_node_id - and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id - ) - ): - current_parallel = parent_parallel_parent_parallel - - cls._recursively_add_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - start_node_id=graph_edge.target_node_id, - parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel, - ) + return current_parallel @classmethod def _check_exceed_parallel_limit(