Skip to content

Commit

Permalink
fix(workflow): fix merge branch node id err
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Sep 2, 2024
1 parent 29b1ce7 commit 0dabf79
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def init(cls,
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
Expand Down Expand Up @@ -310,6 +311,7 @@ def _check_connected_to_previous_node(
@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str]) -> None:
Expand Down Expand Up @@ -365,6 +367,7 @@ def _recursively_add_parallels(cls,

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids
)

Expand Down Expand Up @@ -412,6 +415,7 @@ def _recursively_add_parallels(cls,
for graph_edge in target_node_edges:
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
Expand Down Expand Up @@ -472,6 +476,7 @@ def _recursively_add_parallel_node_ids(cls,
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
Expand Down Expand Up @@ -499,7 +504,11 @@ def _fetch_all_node_ids_in_parallels(cls,
leaf_node_ids[branch_node_id].append(node_id)

for branch_node_id2, inner_route2 in routes_node_ids.items():
if branch_node_id != branch_node_id2 and node_id in inner_route2:
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []

Expand Down

0 comments on commit 0dabf79

Please sign in to comment.