diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.syncNodes.test.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.syncNodes.test.ts index 5c6596c9d6..f2ddc9f833 100644 --- a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.syncNodes.test.ts +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.syncNodes.test.ts @@ -17,6 +17,12 @@ type WorkflowVariable = { variable: string } +type PromptTemplateItem = { + id: string + role: string + text: string +} + const createVariable = (name: string, overrides: Partial = {}): WorkflowVariable => ({ default: '', hint: '', @@ -53,6 +59,43 @@ const createNodeSnapshot = (variableNames: string[]): Node<{ variables: Workflow }, }) +const LLM_NODE_ID = 'llm-node' + +const createLLMNodeSnapshot = (promptTemplates: PromptTemplateItem[]): Node => ({ + id: LLM_NODE_ID, + type: 'custom', + position: { x: 200, y: 120 }, + positionAbsolute: { x: 200, y: 120 }, + height: 320, + width: 460, + selected: false, + selectable: true, + draggable: true, + sourcePosition: 'right', + targetPosition: 'left', + data: { + type: 'llm', + title: 'LLM', + selected: false, + context: { + enabled: false, + variable_selector: [], + }, + model: { + mode: 'chat', + name: 'gemini-2.5-pro', + provider: 'langgenius/gemini/google', + completion_params: { + temperature: 0.7, + }, + }, + vision: { + enabled: false, + }, + prompt_template: promptTemplates, + }, +}) + const getVariables = (node: Node): string[] => { const variables = (node.data as any)?.variables ?? [] return variables.map((item: WorkflowVariable) => item.variable) @@ -63,6 +106,10 @@ const getVariableObject = (node: Node, name: string): WorkflowVariable | undefin return variables.find((item: WorkflowVariable) => item.variable === name) } +const getPromptTemplates = (node: Node): PromptTemplateItem[] => { + return ((node.data as any)?.prompt_template ?? []) as PromptTemplateItem[] +} + describe('CollaborationManager syncNodes', () => { let manager: CollaborationManager @@ -177,4 +224,59 @@ describe('CollaborationManager syncNodes', () => { expect(finalVariables).toEqual(['a', 'b']) expect(getVariableObject(finalNode!, 'b')).toBeDefined() }) + + it('synchronizes prompt_template list updates across collaborators', () => { + const promptManager = new CollaborationManager() + const doc = new LoroDoc() + ;(promptManager as any).doc = doc + ;(promptManager as any).nodesMap = doc.getMap('nodes') + ;(promptManager as any).edgesMap = doc.getMap('edges') + + const baseTemplate = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'avc', + }, + ] + + const baseNode = createLLMNodeSnapshot(baseTemplate) + ;(promptManager as any).syncNodes([], [deepClone(baseNode)]) + + const updatedTemplates = [ + ...baseTemplate, + { + id: 'user-1', + role: 'user', + text: 'hello world', + }, + ] + + const updatedNode = createLLMNodeSnapshot(updatedTemplates) + ;(promptManager as any).syncNodes([deepClone(baseNode)], [deepClone(updatedNode)]) + + const stored = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + expect(stored).toBeDefined() + + const storedTemplates = getPromptTemplates(stored!) + expect(storedTemplates).toHaveLength(2) + expect(storedTemplates[0]).toEqual(baseTemplate[0]) + expect(storedTemplates[1]).toEqual(updatedTemplates[1]) + + const editedTemplates = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'updated system prompt', + }, + ] + const editedNode = createLLMNodeSnapshot(editedTemplates) + + ;(promptManager as any).syncNodes([deepClone(updatedNode)], [deepClone(editedNode)]) + + const final = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + const finalTemplates = getPromptTemplates(final!) + expect(finalTemplates).toHaveLength(1) + expect(finalTemplates[0].text).toBe('updated system prompt') + }) }) diff --git a/web/app/components/workflow/collaboration/core/collaboration-manager.ts b/web/app/components/workflow/collaboration/core/collaboration-manager.ts index b6047cae0e..d1ebbf119b 100644 --- a/web/app/components/workflow/collaboration/core/collaboration-manager.ts +++ b/web/app/components/workflow/collaboration/core/collaboration-manager.ts @@ -77,6 +77,16 @@ export class CollaborationManager { return typeof list.getAttached === 'function' ? list.getAttached() ?? list : list } + private ensurePromptTemplateList(nodeContainer: LoroMap): LoroList { + const dataContainer = this.ensureDataContainer(nodeContainer) + let list = dataContainer.get('prompt_template') as any + + if (!list || typeof list.kind !== 'function' || list.kind() !== 'List') + list = dataContainer.setContainer('prompt_template', new LoroList()) + + return typeof list.getAttached === 'function' ? list.getAttached() ?? list : list + } + private exportNode(nodeId: string): Node { const container = this.getNodeContainer(nodeId) const json = container.toJSON() as any @@ -139,6 +149,8 @@ export class CollaborationManager { if (key === 'variables') this.syncVariables(container, Array.isArray(value) ? value : []) + else if (key === 'prompt_template') + this.syncPromptTemplate(container, Array.isArray(value) ? value : []) else dataContainer.set(key, cloneDeep(value)) }) @@ -150,6 +162,8 @@ export class CollaborationManager { if (key === 'variables') dataContainer.delete('variables') + else if (key === 'prompt_template') + dataContainer.delete('prompt_template') else dataContainer.delete(key) @@ -183,6 +197,28 @@ export class CollaborationManager { } } + private syncPromptTemplate(nodeContainer: LoroMap, desired: any[]): void { + const list = this.ensurePromptTemplateList(nodeContainer) + const current = list.toJSON() as any[] + const target = Array.isArray(desired) ? desired : [] + const minLength = Math.min(current.length, target.length) + + for (let i = 0; i < minLength; i += 1) { + if (!isEqual(current[i], target[i])) { + list.delete(i, 1) + list.insert(i, cloneDeep(target[i])) + } + } + + if (current.length > target.length) { + list.delete(target.length, current.length - target.length) + } + else if (target.length > current.length) { + for (let i = current.length; i < target.length; i += 1) + list.insert(i, cloneDeep(target[i])) + } + } + private getNodePanelPresenceSnapshot(): NodePanelPresenceMap { const snapshot: NodePanelPresenceMap = {} Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {