diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 460e36ae60ab59..b201b28b88d14b 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -283,15 +283,12 @@ export const useWorkflow = () => { return isUsed }, [isVarUsedInNodes]) - const checkParallelLimit = useCallback((nodeId: string) => { + const checkParallelLimit = useCallback((nodeId: string, nodeHandle = 'source') => { const { - getNodes, edges, } = store.getState() - const nodes = getNodes() - const currentNode = nodes.find(node => node.id === nodeId)! - const sourceNodeOutgoers = getOutgoers(currentNode, nodes, edges) - if (sourceNodeOutgoers.length > PARALLEL_LIMIT - 1) { + const connectedEdges = edges.filter(edge => edge.source === nodeId && edge.sourceHandle === nodeHandle) + if (connectedEdges.length > PARALLEL_LIMIT - 1) { const { setShowTips } = workflowStore.getState() setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) return false @@ -322,7 +319,7 @@ export const useWorkflow = () => { return true }, [t, workflowStore]) - const isValidConnection = useCallback(({ source, target }: Connection) => { + const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { edges, getNodes, @@ -331,7 +328,7 @@ export const useWorkflow = () => { const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! - if (!checkParallelLimit(source!)) + if (!checkParallelLimit(source!, sourceHandle || 'source')) return false if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 6e3988eecb64b0..75694983cdcbd8 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -1,6 +1,7 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' import { @@ -10,6 +11,7 @@ import { useAvailableBlocks, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '@/app/components/workflow/hooks' import BlockSelector from '@/app/components/workflow/block-selector' import type { @@ -30,9 +32,11 @@ const Add = ({ isParallel, }: AddProps) => { const { t } = useTranslation() + const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { nodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(nodeData.type, nodeData.isInIteration) + const { checkParallelLimit } = useWorkflow() const handleSelect = useCallback((type, toolDefaultValue) => { handleNodeAdd( @@ -47,6 +51,13 @@ const Add = ({ ) }, [nodeId, sourceHandle, handleNodeAdd]) + const handleOpenChange = useCallback((newOpen: boolean) => { + if (newOpen && !checkParallelLimit(nodeId, sourceHandle)) + return + + setOpen(newOpen) + }, [checkParallelLimit, nodeId, sourceHandle]) + const renderTrigger = useCallback((open: boolean) => { return (
{ e.stopPropagation() - if (checkParallelLimit(id)) + if (checkParallelLimit(id, handleId)) setOpen(v => !v) - }, [checkParallelLimit, id]) + }, [checkParallelLimit, id, handleId]) const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { handleNodeAdd( {