Skip to content

Commit

Permalink
Merge branch 'refs/heads/feat/workflow-parallel-support' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Sep 2, 2024
2 parents 3a4faf9 + 0dabf79 commit 988e85a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 9 deletions.
11 changes: 10 additions & 1 deletion api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def init(cls,
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
Expand Down Expand Up @@ -310,6 +311,7 @@ def _check_connected_to_previous_node(
@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str]) -> None:
Expand Down Expand Up @@ -365,6 +367,7 @@ def _recursively_add_parallels(cls,

in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids
)

Expand Down Expand Up @@ -412,6 +415,7 @@ def _recursively_add_parallels(cls,
for graph_edge in target_node_edges:
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
Expand Down Expand Up @@ -472,6 +476,7 @@ def _recursively_add_parallel_node_ids(cls,
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
Expand Down Expand Up @@ -499,7 +504,11 @@ def _fetch_all_node_ids_in_parallels(cls,
leaf_node_ids[branch_node_id].append(node_id)

for branch_node_id2, inner_route2 in routes_node_ids.items():
if branch_node_id != branch_node_id2 and node_id in inner_route2:
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []

Expand Down
2 changes: 2 additions & 0 deletions web/app/components/workflow/hooks/use-workflow-run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ export const useWorkflowRun = () => {
const currentIndex = draft.tracing!.findIndex((trace) => {
if (!trace.execution_metadata?.parallel_id)
return trace.node_id === data.node_id
if (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id)
return trace.node_id === data.node_id && trace.execution_metadata?.parallel_start_node_id === data.execution_metadata?.parallel_start_node_id

return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
})
Expand Down
2 changes: 2 additions & 0 deletions web/app/components/workflow/panel/debug-and-preview/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ export const useChat = (
const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => {
if (!item.execution_metadata?.parallel_id)
return item.node_id === data.node_id
if (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id)
return item.node_id === data.node_id && item.execution_metadata?.parallel_start_node_id === data.execution_metadata?.parallel_start_node_id

return item.node_id === data.node_id && item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
})
Expand Down
2 changes: 1 addition & 1 deletion web/app/components/workflow/run/node.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ const NodePanel: FC<Props> = ({
<div className='group transition-all bg-background-default border border-components-panel-border rounded-[10px] shadows-shadow-xs hover:shadow-md'>
<div
className={cn(
'flex items-center pl-1 pr-2 cursor-pointer',
'flex items-center pl-1 pr-3 cursor-pointer',
hideInfo ? 'py-2' : 'py-1.5',
!collapseState && (hideInfo ? '!pb-1' : '!pb-1.5'),
)}
Expand Down
21 changes: 14 additions & 7 deletions web/app/components/workflow/run/tracing-panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] {
const levelCounts: { [key: string]: number } = {}
const parallelChildCounts: { [key: string]: Set<string> } = {}

function getParallelTitle(parentId: string | null): string {
const getParallelTitle = (parentId: string | null): string => {
const levelKey = parentId || 'root'
if (!levelCounts[levelKey])
levelCounts[levelKey] = 0
Expand All @@ -50,7 +50,7 @@ function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] {
return `PARALLEL-${levelNumber}${letter}`
}

function getBranchTitle(parentId: string | null, branchNum: number): string {
const getBranchTitle = (parentId: string | null, branchNum: number): string => {
const levelKey = parentId || 'root'
const parentTitle = parentId ? parallelStacks[parentId].parallelTitle : ''
const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1
Expand Down Expand Up @@ -153,6 +153,7 @@ const TracingPanel: FC<TracingPanelProps> = ({ list, onShowIterationDetail }) =>
const newSet = new Set(prev)
if (newSet.has(id))
newSet.delete(id)

else
newSet.add(id)

Expand All @@ -165,12 +166,18 @@ const TracingPanel: FC<TracingPanelProps> = ({ list, onShowIterationDetail }) =>
}, [])

const handleParallelMouseLeave = useCallback((e: React.MouseEvent) => {
const relatedTarget = e.relatedTarget as HTMLElement
const closestParallel = relatedTarget?.closest('[data-parallel-id]')
if (closestParallel)
setHoveredParallel(closestParallel.getAttribute('data-parallel-id'))
else
const relatedTarget = e.relatedTarget as Element | null
if (relatedTarget && 'closest' in relatedTarget) {
const closestParallel = relatedTarget.closest('[data-parallel-id]')
if (closestParallel)
setHoveredParallel(closestParallel.getAttribute('data-parallel-id'))

else
setHoveredParallel(null)
}
else {
setHoveredParallel(null)
}
}, [])

const renderNode = (node: TracingNodeProps) => {
Expand Down

0 comments on commit 988e85a

Please sign in to comment.