From 43240fcd41e8705e9e4d79d77f94337730c2af40 Mon Sep 17 00:00:00 2001 From: StyleZhang Date: Mon, 2 Sep 2024 14:49:30 +0800 Subject: [PATCH] fix --- .../workflow/hooks/use-nodes-interactions.ts | 40 ++++- .../components/workflow/hooks/use-workflow.ts | 53 +++--- web/app/components/workflow/index.tsx | 1 + web/app/components/workflow/utils.ts | 163 ++++++++++++++++++ 4 files changed, 217 insertions(+), 40 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index fabd96a3c3803b..efaf33ecc7a2b9 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -64,6 +64,7 @@ export const useNodesInteractions = () => { const { store: workflowHistoryStore } = useWorkflowHistoryStore() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { + checkNestedParallelLimit, getAfterNodesInSameBranch, } = useWorkflow() const { getNodesReadOnly } = useNodesReadOnly() @@ -372,14 +373,17 @@ export const useNodesInteractions = () => { } }) }) - setNodes(newNodes) const newEdges = produce(edges, (draft) => { draft.push(newEdge) }) - setEdges(newEdges) - handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + } }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) const handleNodeConnectStart = useCallback((_, { nodeId, handleType, handleId }) => { @@ -672,7 +676,7 @@ export const useNodesInteractions = () => { if (newIterationStartNode) draft.push(newIterationStartNode) }) - setNodes(newNodes) + if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { const { setShowAssignVariablePopup } = workflowStore.getState() @@ -696,7 +700,14 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } } if (!prevNodeId && nextNodeId) { const nextNodeIndex = nodes.findIndex(node => node.id === nextNodeId) @@ -775,7 +786,6 @@ export const useNodesInteractions = () => { if (newIterationStartNode) draft.push(newIterationStartNode) }) - setNodes(newNodes) if (newEdge) { const newEdges = produce(edges, (draft) => { draft.forEach((item) => { @@ -786,7 +796,21 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } + } + else { + if (checkNestedParallelLimit(newNodes, edges)) + setNodes(newNodes) + + else + return false } } if (prevNodeId && nextNodeId) { diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index f7ef4a807a4d1d..aa255263271bf7 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -29,8 +29,9 @@ import { useStore, useWorkflowStore, } from '../store' +import { getParallelInfo } from '../utils' import { - // PARALLEL_DEPTH_LIMIT, + PARALLEL_DEPTH_LIMIT, PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' @@ -293,40 +294,27 @@ export const useWorkflow = () => { setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) return false } - // if (sourceNodeOutgoers.length > 0) { - // let hasOverDepth = false - // let parallelDepth = 1 - // const traverse = (root: Node, depth: number) => { - // if (depth > PARALLEL_DEPTH_LIMIT) { - // hasOverDepth = true - // return - // } - // if (depth > parallelDepth) - // parallelDepth = depth - - // const incomerNodes = getIncomers(root, nodes, edges) - - // if (incomerNodes.length) { - // incomerNodes.forEach((incomer) => { - // const incomerOutgoers = getOutgoers(incomer, nodes, edges) - - // if (incomerOutgoers.length > 1) - // traverse(incomer, depth + 1) - // else - // traverse(incomer, depth) - // }) - // } - // } - // traverse(currentNode, parallelDepth) - // if (hasOverDepth) { - // const { setShowTips } = workflowStore.getState() - // setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT })) - // return false - // } - // } + return true }, [store, workflowStore, t]) + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { + const parallelList = getParallelInfo(nodes, edges, parentNodeId) + console.log(parallelList, 'parallelList') + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > PARALLEL_DEPTH_LIMIT) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT })) + return false + } + } + + return true + }, []) + const isValidConnection = useCallback(({ source, target }: Connection) => { const { edges, @@ -392,6 +380,7 @@ export const useWorkflow = () => { removeUsedVarInNodes, isNodeVarsUsedInNodes, checkParallelLimit, + checkNestedParallelLimit, isValidConnection, formatTimeFromNow, getNode, diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index cdccd60a3b5a16..dfda3358824153 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -421,6 +421,7 @@ const WorkflowWrap = memo(() => { citation: features.retriever_resource || { enabled: false }, moderation: features.sensitive_word_avoidance || { enabled: false }, } + // getParallelInfo(nodesData, edgesData) return ( diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 04dc1ac4101562..210662d350fd18 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -1,12 +1,15 @@ import { Position, getConnectedEdges, + getIncomers, getOutgoers, } from 'reactflow' import dagre from '@dagrejs/dagre' import { v4 as uuid4 } from 'uuid' import { cloneDeep, + groupBy, + isEqual, uniqBy, } from 'lodash-es' import type { @@ -589,3 +592,163 @@ export const variableTransformer = (v: ValueSelector | string) => { return `{{#${v.join('.')}#}}` } + +type ParallelInfoItem = { + parallelNodeId: string + depth: number + isBranch?: boolean +} +type NodeParallelInfo = { + parallelNodeId: string + edgeHandleId: string + depth: number +} +type NodeHandle = { + node: Node + handle: string +} +type NodeStreamInfo = { + upstreamNodes: Set + downstreamEdges: Set +} +export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => { + let startNode + + if (parentNodeId) { + const parentNode = nodes.find(node => node.id === parentNodeId) + if (!parentNode) + throw new Error('Parent node not found') + + startNode = nodes.find(node => node.id === (parentNode.data as IterationNodeType).start_node_id) + } + else { + startNode = nodes.find(node => node.data.type === BlockEnum.Start) + } + if (!startNode) + throw new Error('Start node not found') + + const parallelList = [] as ParallelInfoItem[] + const nextNodeHandles = [{ node: startNode, handle: 'source' }] + + const traverse = (firstNodeHandle: NodeHandle) => { + const nodeEdgesSet = {} as Record> + const totalEdgesSet = new Set() + const nextHandles = [firstNodeHandle] + const streamInfo = {} as Record + const parallelListItem = { + parallelNodeId: '', + depth: 0, + } as ParallelInfoItem + const nodeParallelInfoMap = {} as Record + nodeParallelInfoMap[firstNodeHandle.node.id] = { + parallelNodeId: '', + edgeHandleId: '', + depth: 0, + } + + while (nextHandles.length) { + const currentNodeHandle = nextHandles.shift()! + const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle + const currentNodeHandleKey = currentNode.id + const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle) + const connectedEdgesLength = connectedEdges.length + const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id)) + const incomers = getIncomers(currentNode, nodes, edges) + + if (!streamInfo[currentNodeHandleKey]) { + streamInfo[currentNodeHandleKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) { + const newSet = new Set() + for (const item of totalEdgesSet) { + if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item)) + newSet.add(item) + } + if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) { + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + nextNodeHandles.push({ node: currentNode, handle: currentHandle }) + break + } + } + + if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth) + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + + outgoers.forEach((outgoer) => { + const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id) + const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle') + + Object.keys(sourceEdgesGroup).sort((a, b) => { + return sourceEdgesGroup[b].length - sourceEdgesGroup[a].length + }).forEach((sourceHandle) => { + nextHandles.push({ node: outgoer, handle: sourceHandle }) + }) + if (!outgoerConnectedEdges.length) + nextHandles.push({ node: outgoer, handle: 'source' }) + + const outgoerKey = outgoer.id + if (!nodeEdgesSet[outgoerKey]) + nodeEdgesSet[outgoerKey] = new Set() + + if (nodeEdgesSet[currentNodeHandleKey]) { + for (const item of nodeEdgesSet[currentNodeHandleKey]) + nodeEdgesSet[outgoerKey].add(item) + } + + if (!streamInfo[outgoerKey]) { + streamInfo[outgoerKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (!nodeParallelInfoMap[outgoer.id]) { + nodeParallelInfoMap[outgoer.id] = { + ...nodeParallelInfoMap[currentNode.id], + } + } + + if (connectedEdgesLength > 1) { + const edge = connectedEdges.find(edge => edge.target === outgoer.id)! + nodeEdgesSet[outgoerKey].add(edge.id) + totalEdgesSet.add(edge.id) + + streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id) + streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey) + + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[item].downstreamEdges.add(edge.id) + + if (!parallelListItem.parallelNodeId) + parallelListItem.parallelNodeId = currentNode.id + + const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1 + const currentDepth = nodeParallelInfoMap[outgoer.id].depth + + nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth) + } + else { + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[outgoerKey].upstreamNodes.add(item) + + nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth + } + }) + } + + parallelList.push(parallelListItem) + } + + while (nextNodeHandles.length) { + const nodeHandle = nextNodeHandles.shift()! + traverse(nodeHandle) + } + + console.log(parallelList, 'parallelList') + + return parallelList +}