Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(workflow): in multi-parallel execution with multiple conditional branches #8221

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 167 additions & 102 deletions api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,132 +304,197 @@ def _recursively_add_parallels(
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = []
parallel_branch_node_ids = {}
condition_edge_mappings = {}
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids.append(graph_edge.target_node_id)
if "default" not in parallel_branch_node_ids:
parallel_branch_node_ids["default"] = []

parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if not condition_hash in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []

condition_edge_mappings[condition_hash].append(graph_edge)

for _, graph_edges in condition_edge_mappings.items():
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
if condition_hash not in parallel_branch_node_ids:
parallel_branch_node_ids[condition_hash] = []

for graph_edge in graph_edges:
parallel_branch_node_ids.append(graph_edge.target_node_id)
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)

condition_parallels = {}
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
# any target node id in node_parallel_mapping
parallel = None
if condition_parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None

parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
condition_parallels[condition_hash] = parallel

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=condition_parallel_branch_node_ids,
)

# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break

if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id

outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue

node_edges = edge_mapping.get(node_id)
if not node_edges:
continue

if len(node_edges) > 1:
continue

# any target node id in node_parallel_mapping
if parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue

parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue

if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)

if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
parent_parallel=parent_parallel,
)

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
parallel_mapping[parallel.id] = parallel

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)

# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break

if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id

outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue

node_edges = edge_mapping.get(node_id)
if not node_edges:
continue

if len(node_edges) > 1:
continue

target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue

if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue

if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
@classmethod
def _get_current_parallel(
cls,
parallel_mapping: dict[str, GraphParallel],
graph_edge: GraphEdge,
parallel: Optional[GraphParallel] = None,
parent_parallel: Optional[GraphParallel] = None,
) -> Optional[GraphParallel]:
"""
Get current parallel
"""
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
current_parallel = parent_parallel_parent_parallel

if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

for graph_edge in target_node_edges:
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
):
current_parallel = parent_parallel_parent_parallel

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
return current_parallel

@classmethod
def _check_exceed_parallel_limit(
Expand Down