Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feat/workflow-parallel-support' …
Browse files Browse the repository at this point in the history
…into feat/workflow-parallel-support
  • Loading branch information
takatost committed Sep 2, 2024
2 parents bbc922d + 7035f64 commit 35d9c59
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 158 deletions.
41 changes: 39 additions & 2 deletions web/app/components/workflow/hooks/use-nodes-interactions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ export const useNodesInteractions = () => {
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.NodeConnect)
}
}, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory])
}, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory, checkNestedParallelLimit])

const handleNodeConnectStart = useCallback<OnConnectStart>((_, { nodeId, handleType, handleId }) => {
if (getNodesReadOnly())
Expand Down Expand Up @@ -930,7 +930,7 @@ export const useNodesInteractions = () => {
}
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.NodeAdd)
}, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch])
}, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit])

const handleNodeChange = useCallback((
currentNodeId: string,
Expand Down Expand Up @@ -1254,6 +1254,42 @@ export const useNodesInteractions = () => {
saveStateToHistory(WorkflowHistoryEvent.NodeResize)
}, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory])

const handleNodeDisconnect = useCallback((nodeId: string) => {
if (getNodesReadOnly())
return

const {
getNodes,
setNodes,
edges,
setEdges,
} = store.getState()
const nodes = getNodes()
const currentNode = nodes.find(node => node.id === nodeId)!
const connectedEdges = getConnectedEdges([currentNode], edges)
const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
connectedEdges.map(edge => ({ type: 'remove', edge })),
nodes,
)
const newNodes = produce(nodes, (draft: Node[]) => {
draft.forEach((node) => {
if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
node.data = {
...node.data,
...nodesConnectedSourceOrTargetHandleIdsMap[node.id],
}
}
})
})
setNodes(newNodes)
const newEdges = produce(edges, (draft) => {
return draft.filter(edge => !connectedEdges.find(connectedEdge => connectedEdge.id === edge.id))
})
setEdges(newEdges)
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.EdgeDelete)
}, [store, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory])

const handleHistoryBack = useCallback(() => {
if (getNodesReadOnly() || getWorkflowReadOnly())
return
Expand Down Expand Up @@ -1306,6 +1342,7 @@ export const useNodesInteractions = () => {
handleNodesDuplicate,
handleNodesDelete,
handleNodeResize,
handleNodeDisconnect,
handleHistoryBack,
handleHistoryForward,
}
Expand Down
3 changes: 1 addition & 2 deletions web/app/components/workflow/hooks/use-workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ export const useWorkflow = () => {

const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => {
const parallelList = getParallelInfo(nodes, edges, parentNodeId)
console.log(parallelList, 'parallelList')

for (let i = 0; i < parallelList.length; i++) {
const parallel = parallelList[i]
Expand All @@ -313,7 +312,7 @@ export const useWorkflow = () => {
}

return true
}, [])
}, [t, workflowStore])

const isValidConnection = useCallback(({ source, target }: Connection) => {
const {
Expand Down
1 change: 0 additions & 1 deletion web/app/components/workflow/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,6 @@ const WorkflowWrap = memo(() => {
citation: features.retriever_resource || { enabled: false },
moderation: features.sensitive_word_avoidance || { enabled: false },
}
// getParallelInfo(nodesData, edgesData)

return (
<ReactFlowProvider>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ type AddProps = {
nodeId: string
nodeData: CommonNodeType
sourceHandle: string
branchName?: string
isParallel?: boolean
}
const Add = ({
nodeId,
nodeData,
sourceHandle,
branchName,
isParallel,
}: AddProps) => {
const { t } = useTranslation()
const { handleNodeAdd } = useNodesInteractions()
Expand Down Expand Up @@ -57,23 +57,19 @@ const Add = ({
${nodesReadOnly && '!cursor-not-allowed'}
`}
>
{
branchName && (
<div
className='absolute left-1 right-1 -top-[7.5px] flex items-center h-3 text-[10px] text-text-placeholder font-semibold'
title={branchName.toLocaleUpperCase()}
>
<div className='inline-block px-0.5 rounded-[5px] bg-background-default truncate'>{branchName.toLocaleUpperCase()}</div>
</div>
)
}
<div className='flex items-center justify-center mr-1.5 w-5 h-5 rounded-[5px] bg-background-default-dimm'>
<RiAddLine className='w-3 h-3' />
</div>
{t('workflow.panel.selectNextStep')}
<div className='flex items-center uppercase'>
{
isParallel
? t('workflow.common.addParallelNode')
: t('workflow.panel.selectNextStep')
}
</div>
</div>
)
}, [branchName, t, nodesReadOnly])
}, [t, nodesReadOnly, isParallel])

return (
<BlockSelector
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import Add from './add'
import Item from './item'
import type {
CommonNodeType,
Node,
} from '@/app/components/workflow/types'

type ContainerProps = {
nodeId: string
nodeData: CommonNodeType
sourceHandle: string
nextNodes: Node[]
branchName?: string
}

const Container = ({
nodeId,
nodeData,
sourceHandle,
nextNodes,
branchName,
}: ContainerProps) => {
return (
<div className='p-0.5 space-y-0.5 rounded-[10px] bg-background-section-burn'>
{
branchName && (
<div
className='flex items-center px-2 system-2xs-semibold-uppercase text-text-tertiary truncate'
title={branchName}
>
{branchName}
</div>
)
}
{
nextNodes.map(nextNode => (
<Item
key={nextNode.id}
nodeId={nextNode.id}
data={nextNode.data}
sourceHandle='source'
/>
))
}
<Add
isParallel={!!nextNodes.length}
nodeId={nodeId}
nodeData={nodeData}
sourceHandle={sourceHandle}
/>
</div>
)
}

export default Container
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { memo } from 'react'
import { memo, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import {
getConnectedEdges,
getOutgoers,
Expand All @@ -8,29 +9,45 @@ import {
import { useToolIcon } from '../../../../hooks'
import BlockIcon from '../../../../block-icon'
import type {
Branch,
Node,
} from '../../../../types'
import { BlockEnum } from '../../../../types'
import Add from './add'
import Item from './item'
import Line from './line'
import Container from './container'

type NextStepProps = {
selectedNode: Node
}
const NextStep = ({
selectedNode,
}: NextStepProps) => {
const { t } = useTranslation()
const data = selectedNode.data
const toolIcon = useToolIcon(data)
const store = useStoreApi()
const branches = data._targetBranches || []
const branches = useMemo(() => {
return data._targetBranches || []
}, [data])
const nodeWithBranches = data.type === BlockEnum.IfElse || data.type === BlockEnum.QuestionClassifier
const edges = useEdges()
const outgoers = getOutgoers(selectedNode as Node, store.getState().getNodes(), edges)
const connectedEdges = getConnectedEdges([selectedNode] as Node[], edges).filter(edge => edge.source === selectedNode!.id)

const branchesOutgoers = useMemo(() => {
if (!branches?.length)
return []

return branches.map((branch) => {
const connected = connectedEdges.filter(edge => edge.sourceHandle === branch.id)
const nextNodes = connected.map(edge => outgoers.find(outgoer => outgoer.id === edge.target)!)

return {
branch,
nextNodes,
}
})
}, [branches, connectedEdges, outgoers])

return (
<div className='flex py-1'>
<div className='shrink-0 relative flex items-center justify-center w-9 h-9 bg-background-default rounded-lg border-[0.5px] border-divider-regular shadow-xs'>
Expand All @@ -39,59 +56,32 @@ const NextStep = ({
toolIcon={toolIcon}
/>
</div>
<Line linesNumber={nodeWithBranches ? branches.length : 1} />
<div className='grow'>
<Line
list={nodeWithBranches ? branchesOutgoers.map(item => item.nextNodes.length + 1) : [1]}
/>
<div className='grow space-y-2'>
{
!nodeWithBranches && !!outgoers.length && (
<Item
nodeId={outgoers[0].id}
data={outgoers[0].data}
sourceHandle='source'
/>
)
}
{
!nodeWithBranches && !outgoers.length && (
<Add
!nodeWithBranches && (
<Container
nodeId={selectedNode!.id}
nodeData={selectedNode!.data}
sourceHandle='source'
nextNodes={outgoers}
/>
)
}
{
!!branches?.length && nodeWithBranches && (
branches.map((branch: Branch) => {
const connected = connectedEdges.find(edge => edge.sourceHandle === branch.id)
const target = outgoers.find(outgoer => outgoer.id === connected?.target)

nodeWithBranches && (
branchesOutgoers.map((item, index) => {
return (
<div
key={branch.id}
className='mb-3 last-of-type:mb-0'
>
{
connected && (
<Item
data={target!.data!}
nodeId={target!.id}
sourceHandle={branch.id}
branchName={branch.name}
/>
)
}
{
!connected && (
<Add
key={branch.id}
nodeId={selectedNode!.id}
nodeData={selectedNode!.data}
sourceHandle={branch.id}
branchName={branch.name}
/>
)
}
</div>
<Container
key={item.branch.id}
nodeId={selectedNode!.id}
nodeData={selectedNode!.data}
sourceHandle={item.branch.id}
nextNodes={item.nextNodes}
branchName={item.branch.name || `${t('workflow.nodes.questionClassifiers.class')} ${index + 1}`}
/>
)
})
)
Expand Down
Loading

0 comments on commit 35d9c59

Please sign in to comment.