// Mirror of https://github.com/langgenius/dify.git
// 296 lines · 7.8 KiB · TypeScript
import { useCallback } from 'react'
|
|
import produce from 'immer'
|
|
import type {
|
|
EdgeMouseHandler,
|
|
NodeMouseHandler,
|
|
} from 'reactflow'
|
|
import {
|
|
getConnectedEdges,
|
|
useStoreApi,
|
|
} from 'reactflow'
|
|
import type {
|
|
BlockEnum,
|
|
Node,
|
|
SelectedNode,
|
|
} from './types'
|
|
import { NodeInitialData } from './constants'
|
|
import { useStore } from './store'
|
|
import { initialNodesPosition } from './utils'
|
|
|
|
export const useWorkflow = () => {
|
|
const store = useStoreApi()
|
|
const setSelectedNode = useStore(state => state.setSelectedNode)
|
|
|
|
/**
 * Mouse-enter handler for a node.
 *
 * Sets `data.hovering = true` on the node with the same id in the store and
 * sets `data.connectedNodeIsHovering = true` on every edge connected to the
 * hovered node.
 *
 * A node id that is not present in the store is skipped instead of throwing,
 * in line with how the edge lookup already tolerates missing edges.
 *
 * @param _ - The mouse event (unused).
 * @param node - The node being entered.
 */
const handleEnterNode = useCallback<NodeMouseHandler>((_, node) => {
  const {
    getNodes,
    setNodes,
    edges,
    setEdges,
  } = store.getState()

  const newNodes = produce(getNodes(), (draft) => {
    const currentNode = draft.find(n => n.id === node.id)

    if (currentNode)
      currentNode.data.hovering = true
  })
  setNodes(newNodes)

  // Resolved once, outside the recipe: it only reads the unmodified edge list.
  const connectedEdgeIds = new Set(getConnectedEdges([node], edges).map(edge => edge.id))
  const newEdges = produce(edges, (draft) => {
    draft.forEach((edge) => {
      if (connectedEdgeIds.has(edge.id))
        edge.data = { ...edge.data, connectedNodeIsHovering: true }
    })
  })
  setEdges(newEdges)
}, [store])
|
|
|
|
/**
 * Mouse-leave handler for a node.
 *
 * Sets `data.hovering = false` on the node with the same id in the store and
 * sets `data.connectedNodeIsHovering = false` on every edge connected to the
 * node that was left.
 *
 * A node id that is not present in the store is skipped instead of throwing,
 * in line with how the edge lookup already tolerates missing edges.
 *
 * @param _ - The mouse event (unused).
 * @param node - The node being left.
 */
const handleLeaveNode = useCallback<NodeMouseHandler>((_, node) => {
  const {
    getNodes,
    setNodes,
    edges,
    setEdges,
  } = store.getState()

  const newNodes = produce(getNodes(), (draft) => {
    const currentNode = draft.find(n => n.id === node.id)

    if (currentNode)
      currentNode.data.hovering = false
  })
  setNodes(newNodes)

  // Resolved once, outside the recipe: it only reads the unmodified edge list.
  const connectedEdgeIds = new Set(getConnectedEdges([node], edges).map(edge => edge.id))
  const newEdges = produce(edges, (draft) => {
    draft.forEach((edge) => {
      if (connectedEdgeIds.has(edge.id))
        edge.data = { ...edge.data, connectedNodeIsHovering: false }
    })
  })
  setEdges(newEdges)
}, [store])
|
|
|
|
/**
 * Mouse-enter handler for an edge: sets `data.hovering = true` on the edge
 * with the same id in the store. Does nothing to the edge list contents when
 * no such edge exists.
 *
 * @param _ - The mouse event (unused).
 * @param edge - The edge being entered.
 */
const handleEnterEdge = useCallback<EdgeMouseHandler>((_, edge) => {
  const { edges, setEdges } = store.getState()

  const hoveredEdges = produce(edges, (draftEdges) => {
    const target = draftEdges.find(candidate => candidate.id === edge.id)

    // Unknown edge id: leave the draft untouched.
    if (!target)
      return

    target.data = { ...target.data, hovering: true }
  })

  setEdges(hoveredEdges)
}, [store])
|
|
|
|
/**
 * Mouse-leave handler for an edge: sets `data.hovering = false` on the edge
 * with the same id in the store. Does nothing to the edge list contents when
 * no such edge exists.
 *
 * @param _ - The mouse event (unused).
 * @param edge - The edge being left.
 */
const handleLeaveEdge = useCallback<EdgeMouseHandler>((_, edge) => {
  const { edges, setEdges } = store.getState()

  const restedEdges = produce(edges, (draftEdges) => {
    const target = draftEdges.find(candidate => candidate.id === edge.id)

    // Unknown edge id: leave the draft untouched.
    if (!target)
      return

    target.data = { ...target.data, hovering: false }
  })

  setEdges(restedEdges)
}, [store])
|
|
|
|
/**
 * Selects a node, or clears the selection.
 *
 * With `cancelSelection` set, the workflow store's selected node is reset to
 * `null` and every node gets `data.selected = false`. Otherwise `selectNode`
 * becomes the selected node and only the node with the same id gets
 * `data.selected = true`.
 *
 * @param selectNode - The node to select (not read when cancelling).
 * @param cancelSelection - When truthy, clear the selection instead.
 */
const handleSelectNode = useCallback((selectNode: SelectedNode, cancelSelection?: boolean) => {
  const { getNodes, setNodes } = store.getState()

  // Update the store-level selection first, then the per-node flags.
  setSelectedNode(cancelSelection ? null : selectNode)

  // Evaluated lazily per node so `selectNode` is never read when cancelling.
  const isSelected = (nodeId: string) => !cancelSelection && nodeId === selectNode.id

  const newNodes = produce(getNodes(), (draft) => {
    for (const item of draft)
      item.data.selected = isSelected(item.id)
  })
  setNodes(newNodes)
}, [setSelectedNode, store])
|
|
|
|
/**
 * Shallow-merges `data` into the `data` of the node with the given id, then
 * stores `{ id, data }` (the values passed in, not the merged result) as the
 * selected node.
 *
 * The node is expected to exist: the lookup is asserted non-null, so an
 * unknown id throws from inside the recipe.
 *
 * @param param0 - Id of the node to patch and the data fields to merge in.
 */
const handleUpdateNodeData = useCallback(({ id, data }: SelectedNode) => {
  const { getNodes, setNodes } = store.getState()

  const patchedNodes = produce(getNodes(), (draftNodes) => {
    const target = draftNodes.find(candidate => candidate.id === id)!
    const mergedData = { ...target.data, ...data }

    target.data = mergedData
  })

  setNodes(patchedNodes)
  setSelectedNode({ id, data })
}, [store, setSelectedNode])
|
|
|
|
/**
 * Appends a new node of `nodeType` to the right of an existing node and
 * connects the two with a new edge.
 *
 * The new node is placed 304 units to the right of the current node at the
 * same `y`, is marked `selected`, and becomes the store's selected node; all
 * previously existing nodes get `selected: false`.
 *
 * If `currentNodeId` does not match any node the call is a no-op, rather than
 * throwing while reading the missing node's position.
 *
 * @param currentNodeId - Id of the node the new node is attached to.
 * @param nodeType - Block type used to look up the new node's initial data.
 * @param sourceHandle - Handle id on the current node the new edge starts from.
 */
const handleAddNextNode = useCallback((currentNodeId: string, nodeType: BlockEnum, sourceHandle: string) => {
  const {
    getNodes,
    setNodes,
    edges,
    setEdges,
  } = store.getState()
  const nodes = getNodes()
  const currentNode = nodes.find(node => node.id === currentNodeId)

  // Nothing to attach to (e.g. a stale id): leave the graph untouched.
  if (!currentNode)
    return

  const nextNode: Node = {
    // Timestamp-based id, as used by the other node-creating handler.
    id: `${Date.now()}`,
    type: 'custom',
    data: {
      ...NodeInitialData[nodeType],
      selected: true,
    },
    position: {
      x: currentNode.position.x + 304,
      y: currentNode.position.y,
    },
  }
  const newEdge = {
    id: `${currentNode.id}-${nextNode.id}`,
    type: 'custom',
    source: currentNode.id,
    sourceHandle,
    target: nextNode.id,
    targetHandle: 'target',
  }

  const newNodes = produce(nodes, (draft) => {
    // Only the freshly added node stays selected.
    draft.forEach((node) => {
      node.data = { ...node.data, selected: false }
    })
    draft.push(nextNode)
  })
  setNodes(newNodes)

  const newEdges = produce(edges, (draft) => {
    draft.push(newEdge)
  })
  setEdges(newEdges)

  setSelectedNode(nextNode)
}, [store, setSelectedNode])
|
|
|
|
/**
 * Replaces an existing node with a new node of `nodeType` at the same
 * position.
 *
 * Every edge connected to the replaced node is removed and a single new edge
 * from `parentNodeId` (via `sourceHandle`) to the new node is added.
 *
 * If `currentNodeId` does not match any node the call is a no-op, rather than
 * throwing while resolving the missing node's edges and position.
 *
 * @param parentNodeId - Id of the node the replacement is connected from.
 * @param currentNodeId - Id of the node to replace.
 * @param nodeType - Block type used to look up the new node's initial data.
 * @param sourceHandle - Handle id on the parent node the new edge starts from.
 */
const handleChangeCurrentNode = useCallback((parentNodeId: string, currentNodeId: string, nodeType: BlockEnum, sourceHandle: string) => {
  const {
    getNodes,
    setNodes,
    edges,
    setEdges,
  } = store.getState()
  const nodes = getNodes()
  const currentNode = nodes.find(node => node.id === currentNodeId)

  // Nothing to replace (e.g. a stale id): leave the graph untouched.
  if (!currentNode)
    return

  // Ids of all edges touching the node being replaced; these are dropped below.
  const staleEdgeIds = new Set(getConnectedEdges([currentNode], edges).map(edge => edge.id))

  // NOTE(review): the replacement is flagged `selected: true`, but unlike
  // handleAddNextNode the other nodes are not deselected and the store's
  // selected node is not updated — confirm whether that is intended.
  const newCurrentNode: Node = {
    id: `${Date.now()}`,
    type: 'custom',
    data: {
      ...NodeInitialData[nodeType],
      selected: true,
    },
    position: {
      x: currentNode.position.x,
      y: currentNode.position.y,
    },
  }
  const newEdge = {
    id: `${parentNodeId}-${newCurrentNode.id}`,
    type: 'custom',
    source: parentNodeId,
    sourceHandle,
    target: newCurrentNode.id,
    targetHandle: 'target',
  }

  const newNodes = produce(nodes, (draft) => {
    // The node was found above, so the index is always valid here.
    const index = draft.findIndex(node => node.id === currentNodeId)

    draft.splice(index, 1, newCurrentNode)
  })
  setNodes(newNodes)

  const newEdges = produce(edges, (draft) => {
    // The draft itself is not modified; a replacement array is returned.
    const filtered = draft.filter(edge => !staleEdgeIds.has(edge.id))
    filtered.push(newEdge)

    return filtered
  })
  setEdges(newEdges)
}, [store])
|
|
|
|
/**
 * Applies the initial layout: replaces the nodes with the result of
 * `initialNodesPosition(nodes, edges)` and then sets `hidden = false` on
 * every edge.
 */
const handleInitialLayoutNodes = useCallback(() => {
  const {
    getNodes,
    setNodes,
    edges,
    setEdges,
  } = store.getState()

  const arrangedNodes = initialNodesPosition(getNodes(), edges)
  setNodes(arrangedNodes)

  const visibleEdges = produce(edges, (draftEdges) => {
    for (const edge of draftEdges)
      edge.hidden = false
  })
  setEdges(visibleEdges)
}, [store])
|
|
|
|
/**
 * Recomputes the canvas position of every node from its logical grid
 * position stored in `data.position`.
 *
 * Nodes are grouped by `data.position.x` (their column), ordered within a
 * column by `data.position.y`, and stacked top to bottom: each node is placed
 * below the accumulated heights of the nodes above it, with 39 units added
 * after each node. The canvas `x` is the column index times `220 + 64`.
 *
 * The vertical offset is accumulated across the whole column, so a column
 * with three or more nodes no longer places later nodes on top of earlier
 * ones. A node whose `height` is not set counts as height 0 instead of
 * turning the offsets into NaN.
 */
const handleUpdateNodesPosition = useCallback(() => {
  const {
    getNodes,
    setNodes,
  } = store.getState()

  // Horizontal distance between columns — presumably node width (220) plus
  // gap (64); TODO confirm.
  const columnStride = 220 + 64
  // Added below every node when stacking a column.
  const verticalGap = 39

  const nodes = getNodes()
  const columns = nodes.reduce((acc, cur) => {
    const x = cur.data.position.x

    if (!acc[x])
      acc[x] = [cur]
    else
      acc[x].push(cur)

    return acc
  }, {} as Record<string, Node[]>)
  const heightMap: Record<string, number> = {}

  Object.keys(columns).forEach((key) => {
    let baseHeight = 0

    // Sorting in place is safe: each column array was created by the reduce above.
    columns[key]
      .sort((a, b) => a.data.position.y - b.data.position.y)
      .forEach((node) => {
        heightMap[node.id] = baseHeight
        // Accumulate, so every node starts below ALL nodes above it.
        baseHeight += (node.height ?? 0) + verticalGap
      })
  })

  setNodes(produce(nodes, (draft) => {
    draft.forEach((node) => {
      node.position = {
        ...node.position,
        x: node.data.position.x * columnStride,
        y: heightMap[node.id],
      }
    })
  }))
}, [store])
|
|
|
|
return {
|
|
handleEnterNode,
|
|
handleLeaveNode,
|
|
handleEnterEdge,
|
|
handleLeaveEdge,
|
|
handleSelectNode,
|
|
handleUpdateNodeData,
|
|
handleAddNextNode,
|
|
handleChangeCurrentNode,
|
|
handleInitialLayoutNodes,
|
|
handleUpdateNodesPosition,
|
|
}
|
|
}
|