From a311f88c99883b2f7f5c8b08bbee719251d74c6c Mon Sep 17 00:00:00 2001 From: StyleZhang Date: Tue, 27 Feb 2024 18:02:29 +0800 Subject: [PATCH] compute node position --- web/app/(commonLayout)/workflow/page.tsx | 21 ++++--- .../workflow/block-selector/index.tsx | 6 +- web/app/components/workflow/constants.ts | 10 ++++ web/app/components/workflow/hooks.ts | 48 ++++++++++++---- web/app/components/workflow/index.tsx | 53 ++++++++++++++++-- .../nodes/_base/components/node-handle.tsx | 19 ++++--- .../components/workflow/nodes/_base/node.tsx | 4 +- web/app/components/workflow/types.ts | 21 ++++++- web/app/components/workflow/utils.ts | 56 +++++++++++++++++++ 9 files changed, 203 insertions(+), 35 deletions(-) diff --git a/web/app/(commonLayout)/workflow/page.tsx b/web/app/(commonLayout)/workflow/page.tsx index 486978ebf2..96e08eab68 100644 --- a/web/app/(commonLayout)/workflow/page.tsx +++ b/web/app/(commonLayout)/workflow/page.tsx @@ -8,7 +8,7 @@ const initialNodes = [ { id: '1', type: 'custom', - position: { x: 130, y: 130 }, + // position: { x: 130, y: 130 }, data: { type: 'start' }, }, { @@ -21,20 +21,20 @@ const initialNodes = [ id: '3', type: 'custom', position: { x: 738, y: 130 }, - data: { type: 'llm' }, + data: { type: 'llm', sortIndexInBranches: 0 }, }, { id: '4', type: 'custom', position: { x: 738, y: 330 }, - data: { type: 'llm' }, - }, - { - id: '5', - type: 'custom', - position: { x: 1100, y: 130 }, - data: { type: 'llm' }, + data: { type: 'llm', sortIndexInBranches: 1 }, }, + // { + // id: '5', + // type: 'custom', + // position: { x: 1100, y: 130 }, + // data: { type: 'llm' }, + // }, ] const initialEdges = [ @@ -44,6 +44,7 @@ const initialEdges = [ source: '1', sourceHandle: 'source', target: '2', + targetHandle: 'target', }, { id: '1', @@ -51,6 +52,7 @@ const initialEdges = [ source: '2', sourceHandle: 'condition1', target: '3', + targetHandle: 'target', }, { id: '2', @@ -58,6 +60,7 @@ const initialEdges = [ source: '2', sourceHandle: 'condition2', target: '4', + targetHandle: 'target', }, ] diff --git a/web/app/components/workflow/block-selector/index.tsx b/web/app/components/workflow/block-selector/index.tsx index 4b32ac7503..712434f5c0 100644 --- a/web/app/components/workflow/block-selector/index.tsx +++ b/web/app/components/workflow/block-selector/index.tsx @@ -59,6 +59,10 @@ const NodeSelector: FC = ({ e.stopPropagation() handleOpenChange(!open) }, [open, handleOpenChange]) + const handleSelect = useCallback((type: BlockEnum) => { + handleOpenChange(false) + onSelect(type) + }, [handleOpenChange, onSelect]) return ( = ({ /> - + diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 436cf66475..3905601f63 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -34,6 +34,16 @@ export const NodeInitialData = { retrieval_mode: 'single', }, [BlockEnum.IfElse]: { + branches: [ + { + id: 'if-true', + name: 'IS TRUE', + }, + { + id: 'if-false', + name: 'IS FALSE', + }, + ], type: BlockEnum.IfElse, title: '', desc: '', diff --git a/web/app/components/workflow/hooks.ts b/web/app/components/workflow/hooks.ts index b3c8da9afa..80d60ffa6e 100644 --- a/web/app/components/workflow/hooks.ts +++ b/web/app/components/workflow/hooks.ts @@ -14,6 +14,7 @@ import type { } from './types' import { NodeInitialData } from './constants' import { useStore } from './store' +import { initialNodesPosition } from './utils' export const useWorkflow = () => { const store = useStoreApi() @@ -43,6 +44,7 @@ export const useWorkflow = () => { }) setEdges(newEdges) }, [store]) + const handleLeaveNode = useCallback((_, node) => { const { getNodes, @@ -67,6 +69,7 @@ export const useWorkflow = () => { }) setEdges(newEdges) }, [store]) + const handleEnterEdge = useCallback((_, edge) => { const { edges, @@ -79,6 +82,7 @@ export const useWorkflow = () => { }) setEdges(newEdges) }, [store]) + const handleLeaveEdge = useCallback((_, edge) => { const { edges, @@ -91,6 +95,7 @@ export const useWorkflow = () => { }) setEdges(newEdges) }, [store]) + const handleSelectNode = useCallback((selectNode: SelectedNode, cancelSelection?: boolean) => { const { getNodes, @@ -99,20 +104,18 @@ export const useWorkflow = () => { if (cancelSelection) { setSelectedNode(null) const newNodes = produce(getNodes(), (draft) => { - const currentNode = draft.find(n => n.id === selectNode.id) - - if (currentNode) - currentNode.data = { ...currentNode.data, selected: false } + draft.forEach((item) => { + item.data = { ...item.data, selected: false } + }) }) setNodes(newNodes) } else { setSelectedNode(selectNode) const newNodes = produce(getNodes(), (draft) => { - const currentNode = draft.find(n => n.id === selectNode.id) - - if (currentNode) - currentNode.data = { ...currentNode.data, selected: true } + draft.forEach((item) => { + item.data = { ...item.data, selected: item.id === selectNode.id } + }) }) setNodes(newNodes) } @@ -130,7 +133,8 @@ export const useWorkflow = () => { setNodes(newNodes) setSelectedNode({ id, data }) }, [store, setSelectedNode]) - const handleAddNextNode = useCallback((currentNodeId: string, nodeType: BlockEnum) => { + + const handleAddNextNode = useCallback((currentNodeId: string, nodeType: BlockEnum, branchId?: string) => { const { getNodes, setNodes, @@ -141,24 +145,47 @@ export const useWorkflow = () => { const currentNode = nodes.find(node => node.id === currentNodeId)! const nextNode = { id: `${Date.now()}`, - data: NodeInitialData[nodeType], + type: 'custom', + data: { ...NodeInitialData[nodeType], selected: true }, position: { x: currentNode.position.x + 304, y: currentNode.position.y, }, } const newNodes = produce(nodes, (draft) => { + draft.forEach((item) => { + item.data = { ...item.data, selected: false } + }) draft.push(nextNode) }) setNodes(newNodes) const newEdges = produce(edges, (draft) => { draft.push({ id: `${currentNode.id}-${nextNode.id}`, + type: 'custom', source: currentNode.id, + sourceHandle: branchId || 'source', target: nextNode.id, + targetHandle: 'target', }) }) setEdges(newEdges) + setSelectedNode(nextNode) + }, [store, setSelectedNode]) + const handleInitialLayoutNodes = useCallback(() => { + const { + getNodes, + setNodes, + edges, + setEdges, + } = store.getState() + + setNodes(initialNodesPosition(getNodes(), edges)) + setEdges(produce(edges, (draft) => { + draft.forEach((edge) => { + edge.hidden = false + }) + })) }, [store]) return { @@ -169,5 +196,6 @@ export const useWorkflow = () => { handleSelectNode, handleUpdateNodeData, handleAddNextNode, + handleInitialLayoutNodes, } } diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index dc416de9a5..7b6a9ce526 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -1,10 +1,16 @@ import type { FC } from 'react' -import { memo, useEffect } from 'react' +import { + memo, + useEffect, + useMemo, +} from 'react' +import produce from 'immer' import type { Edge } from 'reactflow' import ReactFlow, { Background, ReactFlowProvider, useEdgesState, + useNodesInitialized, useNodesState, } from 'reactflow' import 'reactflow/dist/style.css' @@ -15,7 +21,7 @@ import ZoomInOut from './zoom-in-out' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' import Panel from './panel' -import type { Node } from './types' +import { BlockEnum, type Node } from './types' const nodeTypes = { custom: CustomNode, @@ -34,8 +40,41 @@ const Workflow: FC = memo(({ edges: initialEdges, selectedNodeId: initialSelectedNodeId, }) => { - const [nodes] = useNodesState(initialNodes) - const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges) + const initialData: { + nodes: Node[] + edges: Edge[] + needUpdatePosition: boolean + } = useMemo(() => { + const start = initialNodes.find(node => node.data.type === BlockEnum.Start) + + if (start?.position) { + return { + nodes: initialNodes, + edges: initialEdges, + needUpdatePosition: false, + } + } + + return { + nodes: produce(initialNodes, (draft) => { + draft.forEach((node) => { + node.position = { x: 0, y: 0 } + node.data = { ...node.data, hidden: true } + }) + }), + edges: produce(initialEdges, (draft) => { + draft.forEach((edge) => { + edge.hidden = true + }) + }), + needUpdatePosition: true, + } + }, [initialNodes, initialEdges]) + const nodesInitialized = useNodesInitialized({ + includeHiddenNodes: true, + }) + const [nodes] = useNodesState(initialData.nodes) + const [edges, setEdges, onEdgesChange] = useEdgesState(initialData.edges) const { handleEnterNode, @@ -43,8 +82,14 @@ const Workflow: FC = memo(({ handleEnterEdge, handleLeaveEdge, handleSelectNode, + handleInitialLayoutNodes, } = useWorkflow() + useEffect(() => { + if (nodesInitialized && initialData.needUpdatePosition) + handleInitialLayoutNodes() + }, [nodesInitialized]) + useEffect(() => { if (initialSelectedNodeId) { const initialSelectedNode = nodes.find(n => n.id === initialSelectedNodeId) diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index 1fc5d7e08e..840ce2e66f 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -7,12 +7,12 @@ import { Handle, Position, getConnectedEdges, - getIncomers, useStoreApi, } from 'reactflow' import { BlockEnum } from '../../../types' import type { Node } from '../../../types' import BlockSelector from '../../../block-selector' +import { useWorkflow } from '../../../hooks' type NodeHandleProps = { handleId?: string @@ -29,12 +29,13 @@ export const NodeTargetHandle = ({ }: NodeHandleProps) => { const [open, setOpen] = useState(false) const store = useStoreApi() - const incomers = getIncomers({ id } as Node, store.getState().getNodes(), store.getState().edges) + const connectedEdges = getConnectedEdges([{ id } as Node], store.getState().edges) + const connected = connectedEdges.find(edge => edge.targetHandle === handleId && edge.target === id) const handleOpenChange = useCallback((v: boolean) => { setOpen(v) }, []) const handleHandleClick = () => { - if (incomers.length === 0 && data.type !== BlockEnum.Start) + if (!connected) handleOpenChange(!open) } @@ -47,7 +48,7 @@ export const NodeTargetHandle = ({ className={` !w-4 !h-4 !bg-transparent !rounded-none !outline-none !border-none !translate-y-0 z-[1] after:absolute after:w-0.5 after:h-2 after:left-1.5 after:top-1 after:bg-primary-500 - ${!incomers.length && 'after:opacity-0'} + ${!connected && 'after:opacity-0'} ${data.type === BlockEnum.Start && 'opacity-0'} ${handleClassName} `} @@ -55,7 +56,7 @@ export const NodeTargetHandle = ({ onClick={handleHandleClick} > { - incomers.length === 0 && data.type !== BlockEnum.Start && ( + !connected && data.type !== BlockEnum.Start && ( { const [open, setOpen] = useState(false) + const { handleAddNextNode } = useWorkflow() const store = useStoreApi() const connectedEdges = getConnectedEdges([{ id } as Node], store.getState().edges) - const connected = connectedEdges.find(edge => edge.sourceHandle === handleId) + const connected = connectedEdges.find(edge => edge.sourceHandle === handleId && edge.source === id) const handleOpenChange = useCallback((v: boolean) => { setOpen(v) }, []) @@ -94,6 +96,9 @@ export const NodeSourceHandle = ({ if (!connected) handleOpenChange(!open) } + const handleSelect = useCallback((type: BlockEnum) => { + handleAddNextNode(id, type) + }, [handleAddNextNode, id]) return ( <> @@ -114,7 +119,7 @@ export const NodeSourceHandle = ({ {}} + onSelect={handleSelect} asChild triggerClassName={open => ` hidden absolute top-0 left-0 pointer-events-none diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 04f78fbce5..0b21e8f81f 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -18,7 +18,6 @@ type BaseNodeProps = { const BaseNode: FC = ({ id: nodeId, data, - selected, children, }) => { const { handleSelectNode } = useWorkflow() @@ -28,7 +27,8 @@ const BaseNode: FC = ({ className={` group relative w-[240px] bg-[#fcfdff] rounded-2xl shadow-xs hover:shadow-lg - ${(data.selected && selected) ? 'border-[2px] border-primary-600' : 'border border-white'} + ${data.hidden && 'opacity-0'} + ${data.selected ? 'border-[2px] border-primary-600' : 'border border-white'} `} onClick={() => handleSelectNode({ id: nodeId, data })} > diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 500ffc8803..cc86706c9f 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -1,4 +1,7 @@ -import type { Node as ReactFlowNode } from 'reactflow' +import type { + Edge as ReactFlowEdge, + Node as ReactFlowNode, +} from 'reactflow' export enum BlockEnum { Start = 'start', @@ -15,15 +18,29 @@ export enum BlockEnum { Tool = 'tool', } +export type Branch = { + id: string + name: string +} + export type CommonNodeType = { + hidden?: boolean + position?: { + x: number + y: number + } + sortIndexInBranches?: number + selected?: boolean + hovering?: boolean + branches?: Branch[] title: string desc: string type: BlockEnum - selected?: boolean } export type Node = ReactFlowNode export type SelectedNode = Pick +export type Edge = ReactFlowEdge export type ValueSelector = string[] // [nodeId, key | obj key path] diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index e69de29bb2..e7e62510cb 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -0,0 +1,56 @@ +import { + getOutgoers, +} from 'reactflow' +import { cloneDeep } from 'lodash-es' +import type { + Edge, + Node, +} from './types' +import { BlockEnum } from './types' + +export const initialNodesPosition = (oldNodes: Node[], edges: Edge[]) => { + const nodes = cloneDeep(oldNodes) + const start = nodes.find(node => node.data.type === BlockEnum.Start)! + + start.data.hidden = false + start.position.x = 0 + start.position.y = 0 + start.data.position = { + x: 0, + y: 0, + } + const queue = [start] + + let depth = 0 + let breadth = 0 + let baseHeight = 0 + while (queue.length) { + const node = queue.shift()! + + if (node.data.position?.x !== depth) { + breadth = 0 + baseHeight = 0 + } + + depth = node.data.position?.x || 0 + + const outgoers = getOutgoers(node, nodes, edges).sort((a, b) => (a.data.sortIndexInBranches || 0) - (b.data.sortIndexInBranches || 0)) + + if (outgoers.length) { + queue.push(...outgoers.map((outgoer) => { + outgoer.data.hidden = false + outgoer.data.position = { + x: depth + 1, + y: breadth, + } + outgoer.position.x = (depth + 1) * (220 + 64) + outgoer.position.y = baseHeight + baseHeight += ((outgoer.height || 0) + 39) + breadth += 1 + return outgoer + })) + } + } + + return nodes +}