Commit d13f1fb

Merge branch 'refs/heads/feat/workflow-parallel-support' into deploy/dev
takatost committed Aug 30, 2024
2 parents 2ddcd37 + ee1587c
Showing 12 changed files with 403 additions and 71 deletions.
21 changes: 21 additions & 0 deletions api/core/app/apps/workflow/generate_task_pipeline.py
@@ -19,6 +19,9 @@
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
@@ -280,6 +283,24 @@ def _process_stream_response(

                 if response:
                     yield response
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_finished_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
                     raise Exception('Workflow run not initialized.')
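One detail worth flagging in the hunk above: the combined handler dispatches on a PEP 604 union — `isinstance(event, A | B)` — which is only valid on Python 3.10+. A minimal, self-contained sketch of the pattern (the event classes here are stand-ins, not Dify's real queue entities):

```python
# isinstance() accepts a union type on Python 3.10+, so one branch can handle
# both terminal outcomes of a parallel branch, as the pipeline above does.
class QueueParallelBranchRunSucceededEvent:
    pass

class QueueParallelBranchRunFailedEvent:
    pass

def dispatch(event: object) -> str:
    if isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
        # Success and failure both map to the same "finished" stream response.
        return "parallel_branch_finished"
    return "ignored"

print(dispatch(QueueParallelBranchRunFailedEvent()))  # parallel_branch_finished
```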
8 changes: 8 additions & 0 deletions api/core/app/entities/task_entities.py
@@ -221,6 +221,8 @@ class Data(BaseModel):
         extras: dict = {}
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None

         event: StreamEvent = StreamEvent.NODE_STARTED
         workflow_run_id: str
@@ -243,6 +245,8 @@ def to_ignore_detail_dict(self):
                 "extras": {},
                 "parallel_id": self.data.parallel_id,
                 "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
             }
         }

@@ -274,6 +278,8 @@ class Data(BaseModel):
         files: Optional[list[dict]] = []
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None

         event: StreamEvent = StreamEvent.NODE_FINISHED
         workflow_run_id: str
Expand Down Expand Up @@ -303,6 +309,8 @@ def to_ignore_detail_dict(self):
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
}
}

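The two new fields mirror the existing `parallel_id` / `parallel_start_node_id` pair one level up: a node event inside a nested parallel reports both the branch it runs in and the branch that contains it, and `workflow_cycle_manage.py` below copies them straight from the queue events into these stream responses. A minimal pydantic sketch of how such a payload nests (the model name and values are illustrative, not Dify's; `model_dump()` assumes pydantic v2):

```python
# Sketch: a node started inside parallel "p2", which itself runs inside
# parallel "p1". NodeEventData is a hypothetical stand-in for the Data
# models in the diff above.
from typing import Optional

from pydantic import BaseModel


class NodeEventData(BaseModel):
    parallel_id: Optional[str] = None
    parallel_start_node_id: Optional[str] = None
    parent_parallel_id: Optional[str] = None
    parent_parallel_start_node_id: Optional[str] = None


nested = NodeEventData(
    parallel_id="p2",
    parallel_start_node_id="llm-2",
    parent_parallel_id="p1",
    parent_parallel_start_node_id="llm-1",
)
print(nested.model_dump())
# {'parallel_id': 'p2', 'parallel_start_node_id': 'llm-2',
#  'parent_parallel_id': 'p1', 'parent_parallel_start_node_id': 'llm-1'}
```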
4 changes: 4 additions & 0 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -392,6 +392,8 @@ def _workflow_node_start_to_stream_response(
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
             ),
         )

@@ -444,6 +446,8 @@ def _workflow_node_finish_to_stream_response(
                 files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
             ),
         )

110 changes: 98 additions & 12 deletions api/core/workflow/graph_engine/entities/graph.py
@@ -179,6 +179,15 @@ def init(cls,
             node_parallel_mapping=node_parallel_mapping
         )

+        # Check if it exceeds N layers of parallel
+        for parallel in parallel_mapping.values():
+            if parallel.parent_parallel_id:
+                cls._check_exceed_parallel_limit(
+                    parallel_mapping=parallel_mapping,
+                    level_limit=3,
+                    parent_parallel_id=parallel.parent_parallel_id
+                )
+
         # init answer stream generate routes
         answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
             node_id_config_mapping=node_id_config_mapping,
@@ -315,11 +324,11 @@ def _recursively_add_parallels(cls,
         target_node_edges = edge_mapping.get(start_node_id, [])
         if len(target_node_edges) > 1:
             # fetch all node ids in current parallels
-            parallel_node_ids = []
+            parallel_branch_node_ids = []
             condition_edge_mappings = {}
             for graph_edge in target_node_edges:
                 if graph_edge.run_condition is None:
-                    parallel_node_ids.append(graph_edge.target_node_id)
+                    parallel_branch_node_ids.append(graph_edge.target_node_id)
                 else:
                     condition_hash = graph_edge.run_condition.hash
                     if not condition_hash in condition_edge_mappings:
@@ -330,13 +339,13 @@ def _recursively_add_parallels(cls,
             for _, graph_edges in condition_edge_mappings.items():
                 if len(graph_edges) > 1:
                     for graph_edge in graph_edges:
-                        parallel_node_ids.append(graph_edge.target_node_id)
+                        parallel_branch_node_ids.append(graph_edge.target_node_id)

             # any target node id in node_parallel_mapping
-            if parallel_node_ids:
-                # all parallel_node_ids in node_parallel_mapping
+            if parallel_branch_node_ids:
+                # all parallel_branch_node_ids in node_parallel_mapping
                 parent_parallel_id = None
-                for node_id in parallel_node_ids:
+                for node_id in parallel_branch_node_ids:
                     if node_id in node_parallel_mapping:
                         parent_parallel_id = node_parallel_mapping[node_id]
                         break
@@ -356,7 +365,7 @@

             in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                 edge_mapping=edge_mapping,
-                parallel_node_ids=parallel_node_ids
+                parallel_branch_node_ids=parallel_branch_node_ids
             )

             # collect all branches node ids
@@ -408,6 +417,33 @@ def _recursively_add_parallels(cls,
                     node_parallel_mapping=node_parallel_mapping
                 )

+    @classmethod
+    def _check_exceed_parallel_limit(
+            cls,
+            parallel_mapping: dict[str, GraphParallel],
+            level_limit: int,
+            parent_parallel_id: str,
+            current_level: int = 1
+    ) -> None:
+        """
+        Check if it exceeds N layers of parallel
+        """
+        parent_parallel = parallel_mapping.get(parent_parallel_id)
+        if not parent_parallel:
+            return
+
+        current_level += 1
+        if current_level > level_limit:
+            raise ValueError(f"Exceeds {level_limit} layers of parallel")
+
+        if parent_parallel.parent_parallel_id:
+            cls._check_exceed_parallel_limit(
+                parallel_mapping=parallel_mapping,
+                level_limit=level_limit,
+                parent_parallel_id=parent_parallel.parent_parallel_id,
+                current_level=current_level
+            )
+
     @classmethod
     def _recursively_add_parallel_node_ids(cls,
                                            branch_node_ids: list[str],
@@ -436,19 +472,19 @@ def _recursively_add_parallel_node_ids(cls,
     @classmethod
     def _fetch_all_node_ids_in_parallels(cls,
                                          edge_mapping: dict[str, list[GraphEdge]],
-                                         parallel_node_ids: list[str]) -> dict[str, list[str]]:
+                                         parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
         """
         Fetch all node ids in parallels
         """
         routes_node_ids: dict[str, list[str]] = {}
-        for parallel_node_id in parallel_node_ids:
-            routes_node_ids[parallel_node_id] = [parallel_node_id]
+        for parallel_branch_node_id in parallel_branch_node_ids:
+            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

             # fetch routes node ids
             cls._recursively_fetch_routes(
                 edge_mapping=edge_mapping,
-                start_node_id=parallel_node_id,
-                routes_node_ids=routes_node_ids[parallel_node_id]
+                start_node_id=parallel_branch_node_id,
+                routes_node_ids=routes_node_ids[parallel_branch_node_id]
             )

         # fetch leaf node ids from routes node ids
@@ -472,6 +508,30 @@ def _fetch_all_node_ids_in_parallels(cls,
         # sorted merge_branch_node_ids by branch_node_ids length desc
         merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

+        duplicate_end_node_ids = {}
+        for node_id, branch_node_ids in merge_branch_node_ids.items():
+            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
+                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
+                    if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
+                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
+
+        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
+            # check which node is after
+            if cls._is_node2_after_node1(
+                node1_id=node_id,
+                node2_id=node_id2,
+                edge_mapping=edge_mapping
+            ):
+                if node_id in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id]
+            elif cls._is_node2_after_node1(
+                node1_id=node_id2,
+                node2_id=node_id,
+                edge_mapping=edge_mapping
+            ):
+                if node_id2 in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id2]
+
         branches_merge_node_ids: dict[str, str] = {}
         for node_id, branch_node_ids in merge_branch_node_ids.items():
             if len(branch_node_ids) <= 1:
@@ -526,3 +586,29 @@ def _recursively_fetch_routes(cls,
                 start_node_id=graph_edge.target_node_id,
                 routes_node_ids=routes_node_ids
             )
+
+    @classmethod
+    def _is_node2_after_node1(
+            cls,
+            node1_id: str,
+            node2_id: str,
+            edge_mapping: dict[str, list[GraphEdge]]
+    ) -> bool:
+        """
+        is node2 after node1
+        """
+        if node1_id not in edge_mapping:
+            return False
+
+        for graph_edge in edge_mapping[node1_id]:
+            if graph_edge.target_node_id == node2_id:
+                return True
+
+            if cls._is_node2_after_node1(
+                node1_id=graph_edge.target_node_id,
+                node2_id=node2_id,
+                edge_mapping=edge_mapping
+            ):
+                return True
+
+        return False
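Two of the helpers added in this file are easy to exercise in isolation: `_check_exceed_parallel_limit` climbs `parent_parallel_id` links and rejects nesting deeper than `level_limit` (hard-coded to 3 at the call site in `init`), and `_is_node2_after_node1` is a forward depth-first walk over `edge_mapping` with no visited set, so it assumes the workflow graph is acyclic; the duplicate-end-node pass above uses it to drop the upstream candidate when two merge nodes close over the same branch set. A standalone sketch with simplified stand-in types (not Dify's `GraphParallel`/`GraphEdge`):

```python
# Standalone sketch of the two helpers: nesting-depth enforcement over a
# parent chain, and forward reachability over an adjacency mapping.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Parallel:
    id: str
    parent_parallel_id: Optional[str] = None


def check_depth(parallels: dict[str, Parallel], parent_parallel_id: str,
                level_limit: int = 3, current_level: int = 1) -> None:
    """Mirror of _check_exceed_parallel_limit: count one level per ancestor
    and fail once the chain exceeds level_limit."""
    parent = parallels.get(parent_parallel_id)
    if not parent:
        return
    current_level += 1
    if current_level > level_limit:
        raise ValueError(f"Exceeds {level_limit} layers of parallel")
    if parent.parent_parallel_id:
        check_depth(parallels, parent.parent_parallel_id, level_limit, current_level)


def is_node2_after_node1(edge_mapping: dict[str, list[str]],
                         node1_id: str, node2_id: str) -> bool:
    """Mirror of _is_node2_after_node1: depth-first walk of outgoing edges;
    assumes an acyclic graph."""
    for target in edge_mapping.get(node1_id, []):
        if target == node2_id or is_node2_after_node1(edge_mapping, target, node2_id):
            return True
    return False


parallels = {
    "p1": Parallel("p1"),
    "p2": Parallel("p2", parent_parallel_id="p1"),
    "p3": Parallel("p3", parent_parallel_id="p2"),
}
check_depth(parallels, "p2")  # p3 sits three layers deep: still allowed

parallels["p4"] = Parallel("p4", parent_parallel_id="p3")
try:
    check_depth(parallels, "p3")  # p4 would be the fourth layer
except ValueError as e:
    print(e)  # Exceeds 3 layers of parallel

edges = {"start": ["llm1", "llm2"], "llm1": ["answer"], "llm2": ["answer"]}
print(is_node2_after_node1(edges, "start", "answer"))  # True
print(is_node2_after_node1(edges, "answer", "start"))  # False
```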
10 changes: 7 additions & 3 deletions api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
@@ -245,9 +245,13 @@ def test_parallels_graph():

     assert graph.root_node_id == "start"
     for i in range(3):
-        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
-        assert graph.edge_mapping.get(f"llm{i+1}") is not None
-        assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
+        start_edges = graph.edge_mapping.get("start")
+        assert start_edges is not None
+        assert start_edges[i].target_node_id == f"llm{i+1}"
+
+        llm_edges = graph.edge_mapping.get(f"llm{i+1}")
+        assert llm_edges is not None
+        assert llm_edges[0].target_node_id == "answer"

     assert len(graph.parallel_mapping) == 1
     assert len(graph.node_parallel_mapping) == 3
17 changes: 5 additions & 12 deletions web/app/components/base/chat/chat/answer/workflow-process.tsx
@@ -11,10 +11,10 @@ import {
 } from '@remixicon/react'
 import { useTranslation } from 'react-i18next'
 import type { ChatItem, WorkflowProcess } from '../../types'
+import TracingPanel from '@/app/components/workflow/run/tracing-panel'
 import cn from '@/utils/classnames'
 import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general'
 import { WorkflowRunningStatus } from '@/app/components/workflow/types'
-import NodePanel from '@/app/components/workflow/run/node'
 import { useStore as useAppStore } from '@/app/components/app/store'

 type WorkflowProcessProps = {
@@ -31,7 +31,6 @@ const WorkflowProcessItem = ({
   grayBg,
   expand = false,
   hideInfo = false,
-  hideProcessDetail = false,
 }: WorkflowProcessProps) => {
   const { t } = useTranslation()
   const [collapse, setCollapse] = useState(!expand)
@@ -107,16 +106,10 @@ const WorkflowProcessItem = ({
         !collapse && (
           <div className='mt-1.5'>
             {
-              data.tracing.map(node => (
-                <div key={node.id} className='mb-1 last-of-type:mb-0'>
-                  <NodePanel
-                    nodeInfo={node}
-                    hideInfo={hideInfo}
-                    hideProcessDetail={hideProcessDetail}
-                    onShowIterationDetail={showIterationDetail}
-                  />
-                </div>
-              ))
+              <TracingPanel
+                list={data.tracing}
+                onShowIterationDetail={showIterationDetail}
+              />
             }
           </div>
         )
7 changes: 6 additions & 1 deletion web/app/components/workflow/hooks/use-workflow-run.ts
@@ -329,7 +329,12 @@
       else {
         const nodes = getNodes()
         setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
-          const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id)
+          const currentIndex = draft.tracing!.findIndex((trace) => {
+            if (!trace.execution_metadata?.parallel_id)
+              return trace.node_id === data.node_id
+
+            return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
+          })

           if (currentIndex > -1 && draft.tracing) {
             draft.tracing[currentIndex] = {
7 changes: 6 additions & 1 deletion web/app/components/workflow/panel/debug-and-preview/hooks.ts
@@ -382,7 +382,12 @@ export const useChat = (
         }))
       }
       else {
-        const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
+        const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => {
+          if (!item.execution_metadata?.parallel_id)
+            return item.node_id === data.node_id
+
+          return item.node_id === data.node_id && item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
+        })
         responseItem.workflowProcess!.tracing[currentIndex] = {
           ...(responseItem.workflowProcess!.tracing[currentIndex].extras
             ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras }
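Both frontend hooks (use-workflow-run.ts above and debug-and-preview/hooks.ts here) apply the same matching rule when a node event updates the trace list: outside a parallel, `node_id` alone identifies the entry; inside one, the entry must also carry the event's `parallel_id`, so concurrent runs of the same node in different branches don't overwrite each other's traces. The predicate, sketched in Python for brevity (field names follow the diff; the dict shapes are illustrative):

```python
# Sketch of the findIndex predicate used in both hooks: match by node_id,
# and additionally by parallel_id when the trace entry belongs to a
# parallel branch.
def trace_matches(trace: dict, data: dict) -> bool:
    trace_parallel_id = (trace.get("execution_metadata") or {}).get("parallel_id")
    if not trace_parallel_id:
        return trace["node_id"] == data["node_id"]
    data_parallel_id = (data.get("execution_metadata") or {}).get("parallel_id")
    return trace["node_id"] == data["node_id"] and trace_parallel_id == data_parallel_id


tracing = [
    {"node_id": "llm", "execution_metadata": {"parallel_id": "p1"}},
    {"node_id": "llm", "execution_metadata": {"parallel_id": "p2"}},
]
event = {"node_id": "llm", "execution_metadata": {"parallel_id": "p2"}}
index = next(i for i, t in enumerate(tracing) if trace_matches(t, event))
print(index)  # 1 — the p2 branch entry, not the first "llm" trace
```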
[4 more changed files not shown]
