Skip to content

Commit

Permalink
fix(workflow): in multi-parallel execution with multiple conditional …
Browse files Browse the repository at this point in the history
…branches (langgenius#8221)
  • Loading branch information
takatost authored and JunXu01 committed Nov 9, 2024
1 parent 6fdd0a0 commit cbdf6f7
Showing 1 changed file with 167 additions and 102 deletions.
269 changes: 167 additions & 102 deletions api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,132 +304,197 @@ def _recursively_add_parallels(
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = []
parallel_branch_node_ids = {}
condition_edge_mappings = {}
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids.append(graph_edge.target_node_id)
if "default" not in parallel_branch_node_ids:
parallel_branch_node_ids["default"] = []

parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if not condition_hash in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []

condition_edge_mappings[condition_hash].append(graph_edge)

for _, graph_edges in condition_edge_mappings.items():
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
if condition_hash not in parallel_branch_node_ids:
parallel_branch_node_ids[condition_hash] = []

for graph_edge in graph_edges:
parallel_branch_node_ids.append(graph_edge.target_node_id)
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)

condition_parallels = {}
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
# any target node id in node_parallel_mapping
parallel = None
if condition_parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None

parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
condition_parallels[condition_hash] = parallel

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=condition_parallel_branch_node_ids,
)

# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break

if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id

outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue

node_edges = edge_mapping.get(node_id)
if not node_edges:
continue

if len(node_edges) > 1:
continue

# any target node id in node_parallel_mapping
if parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue

parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue

if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)

if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
parent_parallel=parent_parallel,
)

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
parallel_mapping[parallel.id] = parallel

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)

# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break

if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id

outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue

node_edges = edge_mapping.get(node_id)
if not node_edges:
continue

if len(node_edges) > 1:
continue

target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue

if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue

if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
@classmethod
def _get_current_parallel(
cls,
parallel_mapping: dict[str, GraphParallel],
graph_edge: GraphEdge,
parallel: Optional[GraphParallel] = None,
parent_parallel: Optional[GraphParallel] = None,
) -> Optional[GraphParallel]:
"""
Get current parallel
"""
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
current_parallel = parent_parallel_parent_parallel

if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

for graph_edge in target_node_edges:
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
):
current_parallel = parent_parallel_parent_parallel

cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
return current_parallel

@classmethod
def _check_exceed_parallel_limit(
Expand Down

0 comments on commit cbdf6f7

Please sign in to comment.