Skip to content

Commit

Permalink
fix(workflow): bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Sep 2, 2024
1 parent 43240fc commit 5bda3a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 28 deletions.
75 changes: 48 additions & 27 deletions api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,21 +309,26 @@ def _check_connected_to_previous_node(
)

@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str]) -> None:
def _recursively_add_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None
) -> None:
"""
Recursively add parallel ids
:param edge_mapping: edge mapping
:param start_node_id: start from node id
:param parallel_mapping: parallel mapping
:param node_parallel_mapping: node parallel mapping
:param parent_parallel: parent parallel
"""
target_node_edges = edge_mapping.get(start_node_id, [])
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = []
Expand All @@ -345,18 +350,7 @@ def _recursively_add_parallels(cls,

# any target node id in node_parallel_mapping
if parallel_branch_node_ids:
# all parallel_branch_node_ids in node_parallel_mapping
parent_parallel_id = None
for node_id in parallel_branch_node_ids:
if node_id in node_parallel_mapping:
parent_parallel_id = node_parallel_mapping[node_id]
break

parent_parallel = None
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
raise Exception(f"Parent parallel {parent_parallel_id} not found")
parent_parallel_id = parent_parallel.id if parent_parallel else None

parallel = GraphParallel(
start_from_node_id=start_node_id,
Expand All @@ -375,8 +369,17 @@ def _recursively_add_parallels(cls,
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break

if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id

outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
Expand Down Expand Up @@ -418,7 +421,8 @@ def _recursively_add_parallels(cls,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
node_parallel_mapping=node_parallel_mapping,
parent_parallel=parallel if parallel else parent_parallel
)

@classmethod
Expand Down Expand Up @@ -538,14 +542,14 @@ def _fetch_all_node_ids_in_parallels(cls,
edge_mapping=edge_mapping
):
if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
del merge_branch_node_ids[node_id]

branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
Expand Down Expand Up @@ -613,13 +617,30 @@ def _is_node_in_routes(cls,
if start_node_id not in reverse_edge_mapping:
return False

all_routes_node_ids = []
for _, node_ids in routes_node_ids.items():
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.append(node_id)
all_routes_node_ids.add(node_id)

if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
parallel_start_node_ids[graph_edge.source_node_id] = []

parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)

parallel_start_node_id = None
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True

if not parallel_start_node_id:
raise Exception("Parallel start node id not found")

for graph_edge in reverse_edge_mapping[start_node_id]:
if graph_edge.source_node_id not in all_routes_node_ids:
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
return False

return True
Expand Down
8 changes: 7 additions & 1 deletion api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,13 @@ def _run_parallel_branches(
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} related parallel not found.')

node_title = node_config.get('data', {}).get('title')
raise GraphRunFailedError(f'Node {node_title} related parallel not found.')

parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
Expand Down
5 changes: 5 additions & 0 deletions api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,11 @@ def test_parallels_graph6():
"source": "code3",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
Expand Down

0 comments on commit 5bda3a3

Please sign in to comment.