From 0d01025254688480d4aea97a50cdec7454ddc89b Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 21 May 2025 16:34:41 +0800 Subject: [PATCH] parallel check --- .../workflow/hooks/use-nodes-interactions.ts | 6 +- .../components/workflow/hooks/use-workflow.ts | 108 +++++++++++++++--- web/app/components/workflow/utils/workflow.ts | 16 +-- 3 files changed, 94 insertions(+), 36 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 56ac59cbb0..3e59c25cc4 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -414,7 +414,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, targetNode)) { setNodes(newNodes) setEdges(newEdges) @@ -819,7 +819,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, prevNode)) { setNodes(newNodes) setEdges(newEdges) } @@ -939,7 +939,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, nextNode)) { setNodes(newNodes) setEdges(newEdges) } diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index dbbe4d200c..e24a5f5b21 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -24,11 +24,14 @@ import { useStore, useWorkflowStore, } from '../store' - +import { getParallelInfo } from '../utils' import { + PARALLEL_DEPTH_LIMIT, PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' +import type { IterationNodeType } from '../nodes/iteration/types' +import type { LoopNodeType } from '../nodes/loop/types' import { CUSTOM_NOTE_NODE } from '../note-node/constants' import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils' import { useAvailableBlocks } from './use-available-blocks' @@ -297,28 +300,96 @@ export const useWorkflow = () => { return true }, [store, workflowStore, t]) - const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { - // const { - // parallelList, - // hasAbnormalEdges, - // } = getParallelInfo(nodes, edges, parentNodeId) - // const { workflowConfig } = workflowStore.getState() + const getRootNodesById = useCallback((nodeId: string) => { + const { + getNodes, + edges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId) - // if (hasAbnormalEdges) - // return false + const rootNodes: Node[] = [] - // for (let i = 0; i < parallelList.length; i++) { - // const parallel = parallelList[i] + if (!currentNode) + return rootNodes - // if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) { - // const { setShowTips } = workflowStore.getState() - // setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) })) - // return false - // } - // } + if (currentNode.parentId) { + const parentNode = nodes.find(node => node.id === currentNode.parentId) + if (parentNode) { + const parentList = getRootNodesById(parentNode.id) + + rootNodes.push(...parentList) + } + } + + const traverse = (root: Node, callback: (node: Node) => void) => { + if (root) { + const incomers = getIncomers(root, nodes, edges) + + if (incomers.length) { + incomers.forEach((node) => { + traverse(node, callback) + }) + } + else { + callback(root) + } + } + } + traverse(currentNode, (node) => { + rootNodes.push(node) + }) + + const length = rootNodes.length + if (length) + return uniqBy(rootNodes, 'id') + + return [] + }, [store]) + + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], targetNode?: Node) => { + const { id, parentId } = targetNode || {} + let startNodes: Node[] = [] + + if (parentId) { + const parentNode = nodes.find(node => node.id === parentId) + if (!parentNode) + throw new Error('Parent node not found') + + const startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id) + if (startNode) + startNodes = [startNode] + } + else { + startNodes = nodes.filter(node => nodesMap?.[node.data.type as BlockEnum]?.metaData.isStart) || [] + } + + if (!startNodes.length) + startNodes = getRootNodesById(id || '') + + for (let i = 0; i < startNodes.length; i++) { + const { + parallelList, + hasAbnormalEdges, + } = getParallelInfo(startNodes[i], nodes, edges) + const { workflowConfig } = workflowStore.getState() + + if (hasAbnormalEdges) + return false + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) })) + return false + } + } + } return true - }, [t, workflowStore]) + }, [t, workflowStore, nodesMap, getRootNodesById]) const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { @@ -382,6 +453,7 @@ export const useWorkflow = () => { getBeforeNodeById, getIterationNodeChildren, getLoopNodeChildren, + getRootNodesById, } } diff --git a/web/app/components/workflow/utils/workflow.ts b/web/app/components/workflow/utils/workflow.ts index 88c31f09b5..f81b255b78 100644 --- a/web/app/components/workflow/utils/workflow.ts +++ b/web/app/components/workflow/utils/workflow.ts @@ -16,8 +16,6 @@ import type { import { BlockEnum, } from '../types' -import type { IterationNodeType } from '../nodes/iteration/types' -import type { LoopNodeType } from '../nodes/loop/types' export const canRunBySingle = (nodeType: BlockEnum) => { return nodeType === BlockEnum.LLM @@ -178,19 +176,7 @@ type NodeStreamInfo = { upstreamNodes: Set downstreamEdges: Set } -export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => { - let startNode - - if (parentNodeId) { - const parentNode = nodes.find(node => node.id === parentNodeId) - if (!parentNode) - throw new Error('Parent node not found') - - startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id) - } - else { - startNode = nodes.find(node => node.data.type === BlockEnum.Start) - } +export const getParallelInfo = (startNode: Node, nodes: Node[], edges: Edge[]) => { if (!startNode) throw new Error('Start node not found')