From e90086c2d2ad9bf60d8970d336844bd50b5fedd2 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Thu, 16 Oct 2025 18:30:15 +0800 Subject: [PATCH] add memory variable --- .../workflow-app/components/workflow-main.tsx | 6 +- .../hooks/use-nodes-sync-draft.ts | 2 +- .../workflow-app/hooks/use-workflow-init.ts | 4 +- .../hooks/use-workflow-refresh-draft.ts | 6 +- web/app/components/workflow/hooks/index.ts | 1 + .../workflow/hooks/use-memory-variable.ts | 119 ++++++++++++++++++ .../components/node-selector.tsx | 13 +- .../panel/chat-variable-panel/index.tsx | 12 +- web/app/components/workflow/types.ts | 2 + .../components/workflow/update-dsl-modal.tsx | 9 +- 10 files changed, 154 insertions(+), 20 deletions(-) create mode 100644 web/app/components/workflow/hooks/use-memory-variable.ts diff --git a/web/app/components/workflow-app/components/workflow-main.tsx b/web/app/components/workflow-app/components/workflow-main.tsx index fd7f3d17af..f639eda08d 100644 --- a/web/app/components/workflow-app/components/workflow-main.tsx +++ b/web/app/components/workflow-app/components/workflow-main.tsx @@ -19,6 +19,7 @@ import { useWorkflowStartRun, } from '../hooks' import { useWorkflowStore } from '@/app/components/workflow/store' +import { useFormatMemoryVariables } from '@/app/components/workflow/hooks' type WorkflowMainProps = Pick const WorkflowMain = ({ @@ -28,6 +29,7 @@ const WorkflowMain = ({ }: WorkflowMainProps) => { const featuresStore = useFeaturesStore() const workflowStore = useWorkflowStore() + const { formatMemoryVariables } = useFormatMemoryVariables() const handleWorkflowDataUpdate = useCallback((payload: any) => { const { @@ -51,9 +53,9 @@ const WorkflowMain = ({ } if (memory_blocks) { const { setMemoryVariables } = workflowStore.getState() - setMemoryVariables(memory_blocks) + setMemoryVariables(formatMemoryVariables(memory_blocks, nodes)) } - }, [featuresStore, workflowStore]) + }, [featuresStore, workflowStore, formatMemoryVariables]) const { doSyncWorkflowDraft, diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 1f8bff8cd5..2ffb3a50cb 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -85,7 +85,7 @@ export const useNodesSyncDraft = () => { }, environment_variables: environmentVariables, conversation_variables: conversationVariables, - memory_blocks: memoryVariables, + memory_blocks: memoryVariables.map(({ node, value_type, more, ...rest }) => rest), hash: syncWorkflowDraftHash, }, } diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts index bb0d5b54ff..70049a3380 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-init.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -18,6 +18,7 @@ import { import type { FetchWorkflowDraftResponse } from '@/types/workflow' import { useWorkflowConfig } from '@/service/use-workflow' import type { FileUploadConfigResponse } from '@/models/common' +import { useFormatMemoryVariables } from '@/app/components/workflow/hooks' export const useWorkflowInit = () => { const workflowStore = useWorkflowStore() @@ -41,6 +42,7 @@ export const useWorkflowInit = () => { data: fileUploadConfigResponse, isLoading: isFileUploadConfigLoading, } = useWorkflowConfig('/files/upload', handleUpdateWorkflowFileUploadConfig) + const { formatMemoryVariables } = useFormatMemoryVariables() const handleGetInitialWorkflowData = useCallback(async () => { try { @@ -53,7 +55,7 @@ export const useWorkflowInit = () => { }, {} as Record), environmentVariables: res.environment_variables?.map(env => env.value_type === 'secret' ? { ...env, value: '[__HIDDEN__]' } : env) || [], conversationVariables: res.conversation_variables || [], - memoryVariables: res.memory_blocks || [], + memoryVariables: formatMemoryVariables((res.memory_blocks || []), res.graph.nodes), }) setSyncWorkflowDraftHash(res.hash) setIsLoading(false) diff --git a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts index 3a2fa18b19..110f96fee9 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts @@ -3,10 +3,12 @@ import { useWorkflowStore } from '@/app/components/workflow/store' import { fetchWorkflowDraft } from '@/service/workflow' import type { WorkflowDataUpdater } from '@/app/components/workflow/types' import { useWorkflowUpdate } from '@/app/components/workflow/hooks' +import { useFormatMemoryVariables } from '@/app/components/workflow/hooks' export const useWorkflowRefreshDraft = () => { const workflowStore = useWorkflowStore() const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() + const { formatMemoryVariables } = useFormatMemoryVariables() const handleRefreshWorkflowDraft = useCallback(() => { const { @@ -28,9 +30,9 @@ export const useWorkflowRefreshDraft = () => { }, {} as Record)) setEnvironmentVariables(response.environment_variables?.map(env => env.value_type === 'secret' ? { ...env, value: '[__HIDDEN__]' } : env) || []) setConversationVariables(response.conversation_variables || []) - setMemoryVariables(response.memory_blocks || []) + setMemoryVariables(formatMemoryVariables((response.memory_blocks || []), response.graph.nodes)) }).finally(() => setIsSyncingWorkflowDraft(false)) - }, [handleUpdateWorkflowCanvas, workflowStore]) + }, [handleUpdateWorkflowCanvas, workflowStore, formatMemoryVariables]) return { handleRefreshWorkflowDraft, diff --git a/web/app/components/workflow/hooks/index.ts b/web/app/components/workflow/hooks/index.ts index 1dbba6b0e2..09e16539fc 100644 --- a/web/app/components/workflow/hooks/index.ts +++ b/web/app/components/workflow/hooks/index.ts @@ -22,3 +22,4 @@ export * from './use-DSL' export * from './use-inspect-vars-crud' export * from './use-set-workflow-vars-with-value' export * from './use-workflow-search' +export * from './use-memory-variable' diff --git a/web/app/components/workflow/hooks/use-memory-variable.ts b/web/app/components/workflow/hooks/use-memory-variable.ts new file mode 100644 index 0000000000..04aac8ca7f --- /dev/null +++ b/web/app/components/workflow/hooks/use-memory-variable.ts @@ -0,0 +1,119 @@ +import { useCallback } from 'react' +import { useStoreApi } from 'reactflow' +import produce from 'immer' +import { + useStore, + useWorkflowStore, +} from '@/app/components/workflow/store' +import { BlockEnum } from '@/app/components/workflow/types' +import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' +import type { MemoryVariable, Node } from '@/app/components/workflow/types' +import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types' + +export const useMemoryVariable = () => { + const workflowStore = useWorkflowStore() + const setMemoryVariables = useStore(s => s.setMemoryVariables) + const store = useStoreApi() + + const handleAddMemoryVariableToNode = useCallback((nodeId: string, memoryVariableId: string) => { + const { getNodes, setNodes } = store.getState() + const nodes = getNodes() + const newNodes = produce(nodes, (draft) => { + const currentNode = draft.find(n => n.id === nodeId) + if (currentNode) { + currentNode.data.memory = { + ...(currentNode.data.memory || {}), + block_id: [...(currentNode.data.memory?.block_id || []), memoryVariableId], + } + } + }) + setNodes(newNodes) + }, [store]) + + const handleDeleteMemoryVariableFromNode = useCallback((nodeId: string, memoryVariableId: string) => { + const { getNodes, setNodes } = store.getState() + const nodes = getNodes() + const newNodes = produce(nodes, (draft) => { + const currentNode = draft.find(n => n.id === nodeId) + if (currentNode) { + currentNode.data.memory = { + ...(currentNode.data.memory || {}), + block_id: currentNode.data.memory?.block_id?.filter((id: string) => id !== memoryVariableId) || [], + } + } + }) + setNodes(newNodes) + }, [store]) + + const handleAddMemoryVariable = useCallback((memoryVariable: MemoryVariable) => { + const { memoryVariables } = workflowStore.getState() + setMemoryVariables([memoryVariable, ...memoryVariables]) + + if (memoryVariable.node) + handleAddMemoryVariableToNode(memoryVariable.node, memoryVariable.id) + }, [setMemoryVariables, workflowStore, handleAddMemoryVariableToNode]) + + const handleUpdateMemoryVariable = useCallback((memoryVariable: MemoryVariable) => { + const { memoryVariables } = workflowStore.getState() + const oldMemoryVariable = memoryVariables.find(v => v.id === memoryVariable.id) + setMemoryVariables(memoryVariables.map(v => v.id === memoryVariable.id ? memoryVariable : v)) + + if (oldMemoryVariable && !oldMemoryVariable?.node && memoryVariable.node) + handleAddMemoryVariableToNode(memoryVariable.node, memoryVariable.id) + else if (oldMemoryVariable && oldMemoryVariable.node && !memoryVariable.node) + handleDeleteMemoryVariableFromNode(oldMemoryVariable.node, memoryVariable.id) + }, [setMemoryVariables, workflowStore, handleAddMemoryVariableToNode, handleDeleteMemoryVariableFromNode]) + + const handleDeleteMemoryVariable = useCallback((memoryVariable: MemoryVariable) => { + const { memoryVariables } = workflowStore.getState() + setMemoryVariables(memoryVariables.filter(v => v.id !== memoryVariable.id)) + + if (memoryVariable.node) + handleDeleteMemoryVariableFromNode(memoryVariable.node, memoryVariable.id) + }, [setMemoryVariables, workflowStore, handleDeleteMemoryVariableFromNode]) + + return { + handleAddMemoryVariable, + handleUpdateMemoryVariable, + handleDeleteMemoryVariable, + } +} + +export const useFormatMemoryVariables = () => { + const formatMemoryVariables = useCallback((memoryVariables: MemoryVariable[], nodes: Node[]) => { + let clonedMemoryVariables = [...memoryVariables] + const nodeScopeMemoryVariablesIds = clonedMemoryVariables.filter(v => v.scope === 'node').map(v => v.id) + const nodeScopeMemoryVariablesMap = nodeScopeMemoryVariablesIds.reduce((acc, id) => { + acc[id] = id + return acc + }, {} as Record) + + if (!!nodeScopeMemoryVariablesIds.length) { + const llmNodes = nodes.filter(n => n.data.type === BlockEnum.LLM) + + clonedMemoryVariables = clonedMemoryVariables.map((v) => { + if (nodeScopeMemoryVariablesMap[v.id]) { + const node = llmNodes.find(n => ((n.data as LLMNodeType).memory?.block_id || []).includes(v.id)) + + return { + ...v, + node: node?.id, + } + } + + return v + }) + } + + return clonedMemoryVariables.map((v) => { + return { + ...v, + value_type: ChatVarType.Memory, + } + }) + }, []) + + return { + formatMemoryVariables, + } +} diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/node-selector.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/node-selector.tsx index 246295ecba..95371bf43d 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/node-selector.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/node-selector.tsx @@ -5,8 +5,9 @@ import { RiArrowDownSLine, RiCheckLine, } from '@remixicon/react' +import { useShallow } from 'zustand/react/shallow' import { - useNodes, + useStore, } from 'reactflow' import { PortalToFollowElem, @@ -14,9 +15,6 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import BlockIcon from '@/app/components/workflow/block-icon' -import type { - CommonNodeType, -} from '@/app/components/workflow/types' import { BlockEnum } from '@/app/components/workflow/types' import cn from '@/utils/classnames' @@ -32,12 +30,13 @@ const NodeSelector: FC = ({ nodeType = BlockEnum.LLM, }) => { const [open, setOpen] = useState(false) - const nodes = useNodes() - const filteredNodes = nodeType ? nodes.filter(node => node.data?.type === nodeType) : nodes + const filteredNodes = useStore(useShallow((s) => { + const nodes = [...s.nodeInternals.values()] + return nodes.filter(node => node.data?.type === nodeType) + })) const currentNode = useMemo(() => filteredNodes.find(node => node.id === value), [filteredNodes, value]) - return ( { const { t } = useTranslation() const docLink = useDocLink() const store = useStoreApi() const workflowStore = useWorkflowStore() + const { handleAddMemoryVariable, handleUpdateMemoryVariable, handleDeleteMemoryVariable } = useMemoryVariable() const setShowChatVariablePanel = useStore(s => s.setShowChatVariablePanel) const varList = useStore(s => s.conversationVariables) as ConversationVariable[] const memoryVariables = useStore(s => s.memoryVariables) as MemoryVariable[] @@ -90,8 +92,8 @@ const ChatVariablePanel = () => { removeUsedVarInNodes(chatVar) const varList = workflowStore.getState().conversationVariables updateChatVarList(varList.filter(v => v.id !== chatVar.id)) - const memoryList = workflowStore.getState().memoryVariables - setMemoryVariables(memoryList.filter(v => v.id !== chatVar.id)) + if (chatVar.value_type === ChatVarType.Memory) + handleDeleteMemoryVariable(chatVar as MemoryVariable) setCacheForDelete(undefined) setShowRemoveConfirm(false) handleVarChanged(chatVar.value_type === ChatVarType.Memory) @@ -110,8 +112,10 @@ const ChatVariablePanel = () => { const handleSave = useCallback(async (chatVar: ConversationVariable | MemoryVariable) => { if (chatVar.value_type === ChatVarType.Memory) { - const memoryVarList = workflowStore.getState().memoryVariables - setMemoryVariables([chatVar, ...memoryVarList]) + if (!currentVar) + handleAddMemoryVariable(chatVar as MemoryVariable) + else + handleUpdateMemoryVariable(chatVar as MemoryVariable) handleVarChanged(true) return } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 6e719ab847..8f348c4dab 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -179,6 +179,8 @@ export type MemoryVariable = { term?: string end_user_editable?: boolean value_type: ChatVarType + node?: string + more?: boolean } export type ConversationVariable = { diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx index 795a2756c0..d63cc4d8f5 100644 --- a/web/app/components/workflow/update-dsl-modal.tsx +++ b/web/app/components/workflow/update-dsl-modal.tsx @@ -39,6 +39,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import { useStore as useAppStore } from '@/app/components/app/store' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { useFormatMemoryVariables } from '@/app/components/workflow/hooks' type UpdateDSLModalProps = { onCancel: () => void @@ -80,6 +81,7 @@ const UpdateDSLModal = ({ if (!file) setFileContent('') } + const { formatMemoryVariables } = useFormatMemoryVariables() const handleWorkflowUpdate = useCallback(async (app_id: string) => { const { @@ -117,20 +119,21 @@ const UpdateDSLModal = ({ moderation: features.sensitive_word_avoidance || { enabled: false }, } + const formattedNodes = initialNodes(nodes, edges) eventEmitter?.emit({ type: WORKFLOW_DATA_UPDATE, payload: { - nodes: initialNodes(nodes, edges), + nodes: formattedNodes, edges: initialEdges(edges, nodes), viewport, features: newFeatures, hash, conversation_variables: conversation_variables || [], environment_variables: environment_variables || [], - memory_blocks: memory_blocks || [], + memory_blocks: formatMemoryVariables(memory_blocks || [], formattedNodes), }, } as any) - }, [eventEmitter]) + }, [eventEmitter, formatMemoryVariables]) const isCreatingRef = useRef(false) const handleImport: MouseEventHandler = useCallback(async () => {