diff --git a/web/app/components/workflow/hooks/use-available-blocks.ts b/web/app/components/workflow/hooks/use-available-blocks.ts index a5c7a529c9..b4e037d29f 100644 --- a/web/app/components/workflow/hooks/use-available-blocks.ts +++ b/web/app/components/workflow/hooks/use-available-blocks.ts @@ -6,7 +6,7 @@ import { BlockEnum } from '../types' import { useNodesMetaData } from './use-nodes-meta-data' const availableBlocksFilter = (nodeType: BlockEnum, inContainer?: boolean) => { - if (inContainer && (nodeType === BlockEnum.Iteration || nodeType === BlockEnum.Loop || nodeType === BlockEnum.End)) + if (inContainer && (nodeType === BlockEnum.Iteration || nodeType === BlockEnum.Loop || nodeType === BlockEnum.End || nodeType === BlockEnum.DataSource || nodeType === BlockEnum.KnowledgeBase)) return false if (!inContainer && nodeType === BlockEnum.LoopEnd) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index e5112ac40f..3e59c25cc4 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -414,7 +414,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, targetNode)) { setNodes(newNodes) setEdges(newEdges) @@ -819,7 +819,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, prevNode)) { setNodes(newNodes) setEdges(newEdges) } @@ -939,7 +939,7 @@ export const useNodesInteractions = () => { draft.push(newEdge) }) - if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + if (checkNestedParallelLimit(newNodes, newEdges, nextNode)) { setNodes(newNodes) setEdges(newEdges) } @@ -1234,13 +1234,13 @@ export const useNodesInteractions = () => { if (nodeId) { // If nodeId is provided, copy that specific node const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start - && node.type !== CUSTOM_ITERATION_START_NODE && node.type !== CUSTOM_LOOP_START_NODE && node.data.type !== BlockEnum.LoopEnd) + && node.type !== CUSTOM_ITERATION_START_NODE && node.type !== CUSTOM_LOOP_START_NODE && node.data.type !== BlockEnum.LoopEnd && node.data.type !== BlockEnum.KnowledgeBase) if (nodeToCopy) setClipboardElements([nodeToCopy]) } else { // If no nodeId is provided, fall back to the current behavior - const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start + const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start && node.data.type !== BlockEnum.DataSource && !node.data.isInIteration && !node.data.isInLoop) if (bundledNodes.length) { @@ -1248,7 +1248,7 @@ export const useNodesInteractions = () => { return } - const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start && node.data.type !== BlockEnum.LoopEnd) + const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start && node.data.type !== BlockEnum.LoopEnd && node.data.type !== BlockEnum.DataSource) if (selectedNode) setClipboardElements([selectedNode]) diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 6587b44fa3..e24a5f5b21 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -12,23 +12,26 @@ import type { Connection, } from 'reactflow' import type { + BlockEnum, Edge, Node, ValueSelector, } from '../types' import { - BlockEnum, WorkflowRunningStatus, } from '../types' import { useStore, useWorkflowStore, } from '../store' - +import { getParallelInfo } from '../utils' import { + PARALLEL_DEPTH_LIMIT, PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' +import type { IterationNodeType } from '../nodes/iteration/types' +import type { LoopNodeType } from '../nodes/loop/types' import { CUSTOM_NOTE_NODE } from '../note-node/constants' import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils' import { useAvailableBlocks } from './use-available-blocks' @@ -41,6 +44,7 @@ import { import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' import { CUSTOM_LOOP_START_NODE } from '@/app/components/workflow/nodes/loop-start/constants' import { basePath } from '@/utils/var' +import { useNodesMetaData } from '.' export const useIsChatMode = () => { const appDetail = useAppStore(s => s.appDetail) @@ -53,6 +57,7 @@ export const useWorkflow = () => { const store = useStoreApi() const workflowStore = useWorkflowStore() const { getAvailableBlocks } = useAvailableBlocks() + const { nodesMap } = useNodesMetaData() const setPanelWidth = useCallback((width: number) => { localStorage.setItem('workflow-node-panel-width', `${width}`) workflowStore.setState({ panelWidth: width }) @@ -64,13 +69,17 @@ export const useWorkflow = () => { edges, } = store.getState() const nodes = getNodes() - let startNode = nodes.find(node => node.data.type === BlockEnum.Start) const currentNode = nodes.find(node => node.id === nodeId) - if (currentNode?.parentId) - startNode = nodes.find(node => node.parentId === currentNode.parentId && (node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_LOOP_START_NODE)) + let startNodes = nodes.filter(node => nodesMap?.[node.data.type as BlockEnum]?.metaData.isStart) || [] - if (!startNode) + if (currentNode?.parentId) { + const startNode = nodes.find(node => node.parentId === currentNode.parentId && (node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_LOOP_START_NODE)) + if (startNode) + startNodes = [startNode] + } + + if (!startNodes.length) return [] const list: Node[] = [] @@ -89,8 +98,10 @@ export const useWorkflow = () => { callback(root) } } - preOrder(startNode, (node) => { - list.push(node) + startNodes.forEach((startNode) => { + preOrder(startNode, (node) => { + list.push(node) + }) }) const incomers = getIncomers({ id: nodeId } as Node, nodes, edges) @@ -100,7 +111,7 @@ export const useWorkflow = () => { return uniqBy(list, 'id').filter((item: Node) => { return SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type) }) - }, [store]) + }, [store, nodesMap]) const getBeforeNodesInSameBranch = useCallback((nodeId: string, newNodes?: Node[], newEdges?: Edge[]) => { const { @@ -227,33 +238,6 @@ export const useWorkflow = () => { return nodes.filter(node => node.parentId === nodeId) }, [store]) - const isFromStartNode = useCallback((nodeId: string) => { - const { getNodes } = store.getState() - const nodes = getNodes() - const currentNode = nodes.find(node => node.id === nodeId) - - if (!currentNode) - return false - - if (currentNode.data.type === BlockEnum.Start) - return true - - const checkPreviousNodes = (node: Node) => { - const previousNodes = getBeforeNodeById(node.id) - - for (const prevNode of previousNodes) { - if (prevNode.data.type === BlockEnum.Start) - return true - if (checkPreviousNodes(prevNode)) - return true - } - - return false - } - - return checkPreviousNodes(currentNode) - }, [store, getBeforeNodeById]) - const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => { const { getNodes, setNodes } = store.getState() const afterNodes = getAfterNodesInSameBranch(nodeId) @@ -316,28 +300,96 @@ export const useWorkflow = () => { return true }, [store, workflowStore, t]) - const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { - // const { - // parallelList, - // hasAbnormalEdges, - // } = getParallelInfo(nodes, edges, parentNodeId) - // const { workflowConfig } = workflowStore.getState() + const getRootNodesById = useCallback((nodeId: string) => { + const { + getNodes, + edges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId) - // if (hasAbnormalEdges) - // return false + const rootNodes: Node[] = [] - // for (let i = 0; i < parallelList.length; i++) { - // const parallel = parallelList[i] + if (!currentNode) + return rootNodes - // if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) { - // const { setShowTips } = workflowStore.getState() - // setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) })) - // return false - // } - // } + if (currentNode.parentId) { + const parentNode = nodes.find(node => node.id === currentNode.parentId) + if (parentNode) { + const parentList = getRootNodesById(parentNode.id) + + rootNodes.push(...parentList) + } + } + + const traverse = (root: Node, callback: (node: Node) => void) => { + if (root) { + const incomers = getIncomers(root, nodes, edges) + + if (incomers.length) { + incomers.forEach((node) => { + traverse(node, callback) + }) + } + else { + callback(root) + } + } + } + traverse(currentNode, (node) => { + rootNodes.push(node) + }) + + const length = rootNodes.length + if (length) + return uniqBy(rootNodes, 'id') + + return [] + }, [store]) + + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], targetNode?: Node) => { + const { id, parentId } = targetNode || {} + let startNodes: Node[] = [] + + if (parentId) { + const parentNode = nodes.find(node => node.id === parentId) + if (!parentNode) + throw new Error('Parent node not found') + + const startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id) + if (startNode) + startNodes = [startNode] + } + else { + startNodes = nodes.filter(node => nodesMap?.[node.data.type as BlockEnum]?.metaData.isStart) || [] + } + + if (!startNodes.length) + startNodes = getRootNodesById(id || '') + + for (let i = 0; i < startNodes.length; i++) { + const { + parallelList, + hasAbnormalEdges, + } = getParallelInfo(startNodes[i], nodes, edges) + const { workflowConfig } = workflowStore.getState() + + if (hasAbnormalEdges) + return false + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) })) + return false + } + } + } return true - }, [t, workflowStore]) + }, [t, workflowStore, nodesMap, getRootNodesById]) const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { @@ -385,13 +437,6 @@ export const useWorkflow = () => { return !hasCycle(targetNode) }, [store, checkParallelLimit, getAvailableBlocks]) - const getNode = useCallback((nodeId?: string) => { - const { getNodes } = store.getState() - const nodes = getNodes() - - return nodes.find(node => node.id === nodeId) || nodes.find(node => node.data.type === BlockEnum.Start) - }, [store]) - return { setPanelWidth, getTreeLeafNodes, @@ -405,11 +450,10 @@ export const useWorkflow = () => { checkParallelLimit, checkNestedParallelLimit, isValidConnection, - isFromStartNode, - getNode, getBeforeNodeById, getIterationNodeChildren, getLoopNodeChildren, + getRootNodesById, } } diff --git a/web/app/components/workflow/utils/workflow.ts b/web/app/components/workflow/utils/workflow.ts index 88c31f09b5..f81b255b78 100644 --- a/web/app/components/workflow/utils/workflow.ts +++ b/web/app/components/workflow/utils/workflow.ts @@ -16,8 +16,6 @@ import type { import { BlockEnum, } from '../types' -import type { IterationNodeType } from '../nodes/iteration/types' -import type { LoopNodeType } from '../nodes/loop/types' export const canRunBySingle = (nodeType: BlockEnum) => { return nodeType === BlockEnum.LLM @@ -178,19 +176,7 @@ type NodeStreamInfo = { upstreamNodes: Set downstreamEdges: Set } -export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => { - let startNode - - if (parentNodeId) { - const parentNode = nodes.find(node => node.id === parentNodeId) - if (!parentNode) - throw new Error('Parent node not found') - - startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id) - } - else { - startNode = nodes.find(node => node.data.type === BlockEnum.Start) - } +export const getParallelInfo = (startNode: Node, nodes: Node[], edges: Edge[]) => { if (!startNode) throw new Error('Start node not found')