diff --git a/web/app/components/workflow/block-selector/tabs.tsx b/web/app/components/workflow/block-selector/tabs.tsx
index c7c4ddb61e..fbcc31dd59 100644
--- a/web/app/components/workflow/block-selector/tabs.tsx
+++ b/web/app/components/workflow/block-selector/tabs.tsx
@@ -1,5 +1,10 @@
-import { useState } from 'react'
+import {
+ memo,
+ useState,
+} from 'react'
+import { useNodeId } from 'reactflow'
import BlockIcon from '../block-icon'
+import { useWorkflowContext } from '../context'
import {
BLOCK_CLASSIFICATIONS,
BLOCK_GROUP_BY_CLASSIFICATION,
@@ -7,7 +12,13 @@ import {
} from './constants'
const Tabs = () => {
+ const {
+ nodes,
+ handleAddNextNode,
+ } = useWorkflowContext()
const [activeTab, setActiveTab] = useState(TABS[0].key)
+ const nodeId = useNodeId()
+ const currentNode = nodes.find(node => node.id === nodeId)
return (
@@ -46,6 +57,10 @@ const Tabs = () => {
{
+ e.stopPropagation()
+ handleAddNextNode(currentNode!, block.type)
+ }}
>
{
)
}
-export default Tabs
+export default memo(Tabs)
diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts
new file mode 100644
index 0000000000..436cf66475
--- /dev/null
+++ b/web/app/components/workflow/constants.ts
@@ -0,0 +1,84 @@
+import { BlockEnum } from './types'
+
+export const NodeInitialData = {
+ [BlockEnum.Start]: {
+ type: BlockEnum.Start,
+ title: '',
+ desc: '',
+ variables: [],
+ },
+ [BlockEnum.End]: {
+ type: BlockEnum.End,
+ title: '',
+ desc: '',
+ outputs: {},
+ },
+ [BlockEnum.DirectAnswer]: {
+ type: BlockEnum.DirectAnswer,
+ title: '',
+ desc: '',
+ variables: [],
+ },
+ [BlockEnum.LLM]: {
+ type: BlockEnum.LLM,
+ title: '',
+ desc: '',
+ variables: [],
+ },
+ [BlockEnum.KnowledgeRetrieval]: {
+ type: BlockEnum.KnowledgeRetrieval,
+ title: '',
+ desc: '',
+ query_variable_selector: [],
+ dataset_ids: [],
+ retrieval_mode: 'single',
+ },
+ [BlockEnum.IfElse]: {
+ type: BlockEnum.IfElse,
+ title: '',
+ desc: '',
+ logical_operator: 'and',
+ conditions: [],
+ },
+ [BlockEnum.Code]: {
+ type: BlockEnum.Code,
+ title: '',
+ desc: '',
+ variables: [],
+ code_language: 'python3',
+ code: '',
+ outputs: [],
+ },
+ [BlockEnum.TemplateTransform]: {
+ type: BlockEnum.TemplateTransform,
+ title: '',
+ desc: '',
+ variables: [],
+ template: '',
+ },
+ [BlockEnum.QuestionClassifier]: {
+ type: BlockEnum.QuestionClassifier,
+ title: '',
+ desc: '',
+ query_variable_selector: [],
+ topics: [],
+ },
+ [BlockEnum.HttpRequest]: {
+ type: BlockEnum.HttpRequest,
+ title: '',
+ desc: '',
+ variables: [],
+ },
+ [BlockEnum.VariableAssigner]: {
+ type: BlockEnum.VariableAssigner,
+ title: '',
+ desc: '',
+ variables: [],
+ output_type: '',
+ },
+ [BlockEnum.Tool]: {
+ type: BlockEnum.Tool,
+ title: '',
+ desc: '',
+ },
+}
diff --git a/web/app/components/workflow/context.tsx b/web/app/components/workflow/context.tsx
index 6b129165a2..1a82ce8f0b 100644
--- a/web/app/components/workflow/context.tsx
+++ b/web/app/components/workflow/context.tsx
@@ -1,20 +1,30 @@
'use client'
import { createContext, useContext } from 'use-context-selector'
-import type { Edge } from 'reactflow'
-import type { Node } from './types'
+import type {
+ Edge,
+ ReactFlowInstance,
+} from 'reactflow'
+import type {
+ BlockEnum,
+ Node,
+} from './types'
export type WorkflowContextValue = {
+ reactFlow: ReactFlowInstance
nodes: Node[]
edges: Edge[]
selectedNodeId?: string
handleSelectedNodeIdChange: (nodeId: string) => void
selectedNode?: Node
+ handleAddNextNode: (prevNode: Node, nextNodeType: BlockEnum) => void
}
export const WorkflowContext = createContext({
+ reactFlow: null as any,
nodes: [],
edges: [],
handleSelectedNodeIdChange: () => {},
+ handleAddNextNode: () => {},
})
export const useWorkflowContext = () => useContext(WorkflowContext)
diff --git a/web/app/components/workflow/hooks.ts b/web/app/components/workflow/hooks.ts
index 631d169f52..c7fb940492 100644
--- a/web/app/components/workflow/hooks.ts
+++ b/web/app/components/workflow/hooks.ts
@@ -1,11 +1,27 @@
+import type {
+ Dispatch,
+ SetStateAction,
+} from 'react'
import {
useCallback,
useMemo,
useState,
} from 'react'
-import type { Node } from './types'
+import produce from 'immer'
+import type { Edge } from 'reactflow'
+import type {
+ BlockEnum,
+ Node,
+} from './types'
+import { NodeInitialData } from './constants'
-export const useWorkflow = (nodes: Node[], initialSelectedNodeId?: string) => {
+export const useWorkflow = (
+ nodes: Node[],
+ edges: Edge[],
+ setNodes: Dispatch>,
+ setEdges: Dispatch>,
+ initialSelectedNodeId?: string,
+) => {
const [selectedNodeId, setSelectedNodeId] = useState(initialSelectedNodeId)
const handleSelectedNodeIdChange = useCallback((nodeId: string) => setSelectedNodeId(nodeId), [])
@@ -14,9 +30,40 @@ export const useWorkflow = (nodes: Node[], initialSelectedNodeId?: string) => {
return nodes.find(node => node.id === selectedNodeId)
}, [nodes, selectedNodeId])
+ const handleAddNextNode = useCallback((prevNode: Node, nextNodeType: BlockEnum) => {
+ const prevNodeDom = document.querySelector(`.react-flow__node-custom[data-id="${prevNode.id}"]`)
+ const prevNodeDomHeight = prevNodeDom?.getBoundingClientRect().height || 0
+
+ const nextNode = {
+ id: `node-${Date.now()}`,
+ type: 'custom',
+ position: {
+ x: prevNode.position.x,
+ y: prevNode.position.y + prevNodeDomHeight + 64,
+ },
+ data: NodeInitialData[nextNodeType],
+ }
+ const newEdge = {
+ id: `edge-${Date.now()}`,
+ source: prevNode.id,
+ target: nextNode.id,
+ }
+ setNodes((oldNodes) => {
+ return produce(oldNodes, (draft) => {
+ draft.push(nextNode)
+ })
+ })
+ setEdges((oldEdges) => {
+ return produce(oldEdges, (draft) => {
+ draft.push(newEdge)
+ })
+ })
+ }, [setNodes, setEdges])
+
return {
selectedNodeId,
selectedNode,
handleSelectedNodeIdChange,
+ handleAddNextNode,
}
}
diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx
index a64be7259b..f0f1d536b1 100644
--- a/web/app/components/workflow/index.tsx
+++ b/web/app/components/workflow/index.tsx
@@ -5,6 +5,7 @@ import ReactFlow, {
ReactFlowProvider,
useEdgesState,
useNodesState,
+ useReactFlow,
} from 'reactflow'
import 'reactflow/dist/style.css'
import {
@@ -60,21 +61,31 @@ const WorkflowWrap: FC = ({
edges: initialEdges,
selectedNodeId: initialSelectedNodeId,
}) => {
- const [nodes] = useNodesState(initialNodes)
- const [edges] = useEdgesState(initialEdges)
+ const reactFlow = useReactFlow()
+ const [nodes, setNodes] = useNodesState(initialNodes)
+ const [edges, setEdges] = useEdgesState(initialEdges)
const {
selectedNodeId,
handleSelectedNodeIdChange,
selectedNode,
- } = useWorkflow(nodes, initialSelectedNodeId)
+ handleAddNextNode,
+ } = useWorkflow(
+ nodes,
+ edges,
+ setNodes,
+ setEdges,
+ initialSelectedNodeId,
+ )
return (
diff --git a/web/app/components/workflow/nodes/constants.ts b/web/app/components/workflow/nodes/constants.ts
index 66f1464eb1..08dbb99793 100644
--- a/web/app/components/workflow/nodes/constants.ts
+++ b/web/app/components/workflow/nodes/constants.ts
@@ -1,4 +1,5 @@
import type { ComponentType } from 'react'
+import { BlockEnum } from '../types'
import StartNode from './start/node'
import StartPanel from './start/panel'
import EndNode from './end/node'
@@ -23,29 +24,29 @@ import ToolNode from './tool/node'
import ToolPanel from './tool/panel'
export const NodeMap: Record = {
- start: StartNode,
- end: EndNode,
- directAnswer: DirectAnswerNode,
- llm: LLMNode,
- knowledgeRetrieval: KnowledgeRetrievalNode,
- questionClassifier: QuestionClassifierNode,
- ifElse: IfElseNode,
- code: CodeNode,
- templateTransform: TemplateTransformNode,
- http: HttpNode,
- tool: ToolNode,
+ [BlockEnum.Start]: StartNode,
+ [BlockEnum.End]: EndNode,
+ [BlockEnum.DirectAnswer]: DirectAnswerNode,
+ [BlockEnum.LLM]: LLMNode,
+ [BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalNode,
+ [BlockEnum.QuestionClassifier]: QuestionClassifierNode,
+ [BlockEnum.IfElse]: IfElseNode,
+ [BlockEnum.Code]: CodeNode,
+ [BlockEnum.TemplateTransform]: TemplateTransformNode,
+ [BlockEnum.HttpRequest]: HttpNode,
+ [BlockEnum.Tool]: ToolNode,
}
export const PanelMap: Record = {
- start: StartPanel,
- end: EndPanel,
- directAnswer: DirectAnswerPanel,
- llm: LLMPanel,
- knowledgeRetrieval: KnowledgeRetrievalPanel,
- questionClassifier: QuestionClassifierPanel,
- ifElse: IfElsePanel,
- code: CodePanel,
- templateTransform: TemplateTransformPanel,
- http: HttpPanel,
- tool: ToolPanel,
+ [BlockEnum.Start]: StartPanel,
+ [BlockEnum.End]: EndPanel,
+ [BlockEnum.DirectAnswer]: DirectAnswerPanel,
+ [BlockEnum.LLM]: LLMPanel,
+ [BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalPanel,
+ [BlockEnum.QuestionClassifier]: QuestionClassifierPanel,
+ [BlockEnum.IfElse]: IfElsePanel,
+ [BlockEnum.Code]: CodePanel,
+ [BlockEnum.TemplateTransform]: TemplateTransformPanel,
+ [BlockEnum.HttpRequest]: HttpPanel,
+ [BlockEnum.Tool]: ToolPanel,
}
diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts
index bb8cad7611..8fd1f35d8b 100644
--- a/web/app/components/workflow/types.ts
+++ b/web/app/components/workflow/types.ts
@@ -13,6 +13,7 @@ export enum BlockEnum {
TemplateTransform = 'template-transform',
HttpRequest = 'http-request',
VariableAssigner = 'variable-assigner',
+ Tool = 'tool',
}
export type NodeData = {