From f72ed4898cdcef32cd904263a22c3e0043757fa7 Mon Sep 17 00:00:00 2001 From: Stream Date: Fri, 22 Aug 2025 14:57:27 +0800 Subject: [PATCH] refactor: refactor from ChatflowHistoryService and ChatflowMemoryService --- api/core/llm_generator/llm_generator.py | 38 +- api/core/llm_generator/prompts.py | 15 + api/core/memory/entities.py | 9 +- api/core/workflow/nodes/llm/node.py | 4 +- api/services/chatflow_history_service.py | 127 ++----- api/services/chatflow_memory_service.py | 427 +++++++---------------- 6 files changed, 220 insertions(+), 400 deletions(-) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 64fc3a3e80..e3dc5f4e56 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import Optional, cast +from typing import Optional, cast, Mapping import json_repair @@ -16,8 +16,9 @@ from core.llm_generator.prompts import ( LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, MEMORY_UPDATE_PROMPT, ) +from core.memory.entities import MemoryBlock, MemoryBlockSpec from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage @@ -572,3 +573,36 @@ class LLMGenerator: except Exception as e: logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) return {"error": f"An unexpected error occurred: {str(e)}"} + + @staticmethod + def update_memory_block( + tenant_id: str, + visible_history: Mapping[str, str], + memory_block: MemoryBlock, + memory_spec: MemoryBlockSpec + ) -> str: + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + provider=memory_spec.model.provider, + model=memory_spec.model.name, + model_type=ModelType.LLM, + ) + formatted_history = "" + for sender, message in visible_history.items(): + formatted_history += f"{sender}: {message}\n" + formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format( + inputs={ + "formatted_history": formatted_history, + "current_value": memory_block.value, + "instruction": memory_spec.instruction, + } + ) + llm_result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=[UserPromptMessage(content=formatted_prompt)], + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ) + ) + return llm_result.message.get_text_content() diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index e38828578a..710ffe54f2 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex You should edit the prompt according to the IDEAL OUTPUT.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" + +MEMORY_UPDATE_PROMPT = """ +Based on the following conversation history, update the memory content: + +Conversation history: +{{formatted_history}} + +Current memory: +{{current_value}} + +Update instruction: +{{instruction}} + +Please output only the updated memory content, no other text like greeting: +""" # noqa: E501 diff --git a/api/core/memory/entities.py b/api/core/memory/entities.py index 175df321fc..f4faf44160 100644 --- a/api/core/memory/entities.py +++ b/api/core/memory/entities.py @@ -1,9 +1,12 @@ +from datetime import datetime from enum import Enum from typing import Any, Optional from uuid import uuid4 from pydantic import BaseModel, Field +from core.app.app_config.entities import ModelConfig + class MemoryScope(str, Enum): """Memory scope determined by node_id field""" @@ -42,7 +45,7 @@ class MemoryBlockSpec(BaseModel): update_turns: int = Field(gt=0, description="Number of turns between updates") preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve") schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode") - model: Optional[dict[str, Any]] = Field(default=None, description="Model configuration for memory updates") + model: ModelConfig = Field(description="Model configuration for memory updates") end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users") end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users") @@ -69,8 +72,8 @@ class MemoryBlock(BaseModel): app_id: str # None=global(future), str=app-specific conversation_id: Optional[str] = None # None=persistent, str=session node_id: Optional[str] = None # None=app-scope, str=node-scope - created_at: Optional[str] = None - updated_at: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None @property def is_global(self) -> bool: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 07b725899d..1ac66bc0f3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1147,9 +1147,9 @@ class LLMNode(BaseNode): ChatflowMemoryService.update_node_memory_if_needed( tenant_id=self.tenant_id, app_id=self.app_id, - memory_block_spec=memory_block_spec, node_id=self.node_id, - llm_output=llm_output, + conversation_id=conversation_id, + memory_block_spec=memory_block_spec, variable_pool=variable_pool, is_draft=is_draft ) diff --git a/api/services/chatflow_history_service.py b/api/services/chatflow_history_service.py index 828758ee85..3612daed3d 100644 --- a/api/services/chatflow_history_service.py +++ b/api/services/chatflow_history_service.py @@ -1,7 +1,7 @@ import json import time from collections.abc import Sequence -from typing import Literal, Optional, overload +from typing import Literal, Optional, overload, MutableMapping from sqlalchemy import Row, Select, and_, func, select from sqlalchemy.orm import Session @@ -17,15 +17,6 @@ from models.chatflow_memory import ChatflowConversation, ChatflowMessage class ChatflowHistoryService: - """ - Service layer for managing chatflow conversation history. - - This unified service handles all chatflow memory operations: - - Reading visible chat history with version control - - Saving messages to append-only table - - Managing visible_count metadata - - Supporting both app-level and node-level scoping - """ @staticmethod def get_visible_chat_history( @@ -35,18 +26,7 @@ class ChatflowHistoryService: node_id: Optional[str] = None, max_visible_count: Optional[int] = None ) -> Sequence[PromptMessage]: - """ - Get visible chat history based on metadata visible_count. - - Args: - conversation_id: Original conversation ID - node_id: None for app-level, specific node_id for node-level - max_visible_count: Override visible_count for memory update operations - - Returns: - Sequence of PromptMessage objects in chronological order (oldest first) - """ - with db.session() as session: + with Session(db.engine) as session: chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation( session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False ) @@ -54,79 +34,19 @@ class ChatflowHistoryService: if not chatflow_conv: return [] - # Parse metadata - metadata_dict = json.loads(chatflow_conv.conversation_metadata) - metadata = ChatflowConversationMetadata.model_validate(metadata_dict) + metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata) + visible_count: int = max_visible_count or metadata.visible_count - # Determine the actual number of messages to return - target_visible_count = max_visible_count if max_visible_count is not None else metadata.visible_count - - # Fetch all messages (handle versioning) - msg_stmt = select(ChatflowMessage).where( + stmt = select(ChatflowMessage).where( ChatflowMessage.conversation_id == chatflow_conv.id ).order_by(ChatflowMessage.index.asc(), ChatflowMessage.version.desc()) - - all_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(msg_stmt).all() - - # Filter in memory: keep only the latest version for each index - latest_messages_by_index: dict[int, ChatflowMessage] = {} - for msg_row in all_messages: - msg = msg_row[0] - index = msg.index - - if index not in latest_messages_by_index or msg.version > latest_messages_by_index[index].version: - latest_messages_by_index[index] = msg - - # Sort by index and take the latest target_visible_count messages - sorted_messages = sorted(latest_messages_by_index.values(), key=lambda m: m.index, reverse=True) - visible_messages = sorted_messages[:target_visible_count] - - # Convert to PromptMessage and restore correct order (oldest first) - prompt_messages: list[PromptMessage] = [] - for msg in reversed(visible_messages): # Restore chronological order (index ascending) - data = json.loads(msg.data) - role = data.get('role', 'user') - content = data.get('content', '') - - if role == 'user': - prompt_messages.append(UserPromptMessage(content=content)) - elif role == 'assistant': - prompt_messages.append(AssistantPromptMessage(content=content)) - - return prompt_messages - - @staticmethod - def get_app_visible_chat_history( - app_id: str, - conversation_id: str, - tenant_id: str, - max_visible_count: Optional[int] = None - ) -> Sequence[PromptMessage]: - """Get visible chat history for app level.""" - return ChatflowHistoryService.get_visible_chat_history( - conversation_id=conversation_id, - app_id=app_id, - tenant_id=tenant_id, - node_id=None, # App level - max_visible_count=max_visible_count - ) - - @staticmethod - def get_node_visible_chat_history( - node_id: str, - conversation_id: str, - app_id: str, - tenant_id: str, - max_visible_count: Optional[int] = None - ) -> Sequence[PromptMessage]: - """Get visible chat history for a specific node.""" - return ChatflowHistoryService.get_visible_chat_history( - conversation_id=conversation_id, - app_id=app_id, - tenant_id=tenant_id, - node_id=node_id, - max_visible_count=max_visible_count - ) + raw_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(stmt).all() + sorted_messages = ChatflowHistoryService._filter_latest_messages( + [it[0] for it in raw_messages] + ) + visible_count = min(visible_count, len(sorted_messages)) + visible_messages = sorted_messages[-visible_count:] + return [PromptMessage.model_validate_json(it.data) for it in visible_messages] @staticmethod def save_message( @@ -136,13 +56,7 @@ class ChatflowHistoryService: tenant_id: str, node_id: Optional[str] = None ) -> None: - """ - Save a message to the append-only chatflow_messages table. - - Args: - node_id: None for app-level, specific node_id for node-level - """ - with db.session() as session: + with Session(db.engine) as session: chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation( session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True ) @@ -216,7 +130,7 @@ class ChatflowHistoryService: """ Save a new version of an existing message (for message editing scenarios). """ - with db.session() as session: + with Session(db.engine) as session: chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation( session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True ) @@ -270,7 +184,7 @@ class ChatflowHistoryService: # Update node-specific visible_count ChatflowHistoryService.update_visible_count(conv_id, "node-123", 8, app_id, tenant_id) """ - with db.session() as session: + with Session(db.engine) as session: chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation( session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True ) @@ -281,6 +195,17 @@ class ChatflowHistoryService: session.commit() + @staticmethod + def _filter_latest_messages(raw_messages: Sequence[ChatflowMessage]) -> Sequence[ChatflowMessage]: + index_to_message: MutableMapping[int, ChatflowMessage] = {} + for msg in raw_messages: + index = msg.index + if index not in index_to_message or msg.version > index_to_message[index].version: + index_to_message[index] = msg + + sorted_messages = sorted(index_to_message.values(), key=lambda m: m.index) + return sorted_messages + @overload @staticmethod def _get_or_create_chatflow_conversation( diff --git a/api/services/chatflow_memory_service.py b/api/services/chatflow_memory_service.py index dc06397da1..aefd4f230c 100644 --- a/api/services/chatflow_memory_service.py +++ b/api/services/chatflow_memory_service.py @@ -13,7 +13,6 @@ from core.memory.entities import ( MemoryBlockWithVisibility, MemoryScheduleMode, MemoryScope, - MemoryStrategy, MemoryTerm, ) from core.memory.errors import MemorySyncTimeoutError @@ -24,7 +23,9 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models import App from models.chatflow_memory import ChatflowMemoryVariable +from models.workflow import WorkflowDraftVariable from services.chatflow_history_service import ChatflowHistoryService +from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -42,11 +43,6 @@ def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str: return f"memory_sync_update:{app_id}:{conversation_id}" class ChatflowMemoryService: - """ - Memory service class with only static methods. - All methods are static and do not require instantiation. - """ - @staticmethod def get_persistent_memories(app: App) -> Sequence[MemoryBlockWithVisibility]: stmt = select(ChatflowMemoryVariable).where( @@ -56,7 +52,7 @@ class ChatflowMemoryService: ChatflowMemoryVariable.conversation_id == None ) ) - with db.session() as session: + with Session(db.engine) as session: db_results = session.execute(stmt).all() return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results]) @@ -69,94 +65,38 @@ class ChatflowMemoryService: ChatflowMemoryVariable.conversation_id == conversation_id ) ) - with db.session() as session: + with Session(db.engine) as session: db_results = session.execute(stmt).all() return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results]) @staticmethod - def get_memory(memory_id: str, tenant_id: str, - app_id: Optional[str] = None, - conversation_id: Optional[str] = None, - node_id: Optional[str] = None) -> Optional[MemoryBlock]: - """Get single memory by ID""" - stmt = select(ChatflowMemoryVariable).where( - and_( - ChatflowMemoryVariable.memory_id == memory_id, - ChatflowMemoryVariable.tenant_id == tenant_id - ) - ) - - if app_id: - stmt = stmt.where(ChatflowMemoryVariable.app_id == app_id) - if conversation_id: - stmt = stmt.where(ChatflowMemoryVariable.conversation_id == conversation_id) - if node_id: - stmt = stmt.where(ChatflowMemoryVariable.node_id == node_id) - - with db.session() as session: - result = session.execute(stmt).first() - if result: - return MemoryBlock.model_validate(result[0].__dict__) - return None - - @staticmethod - def save_memory(memory: MemoryBlock, tenant_id: str, variable_pool: VariablePool, is_draft: bool = False) -> None: - """Save or update memory with draft mode support""" - + def save_memory(memory: MemoryBlock, tenant_id: str, variable_pool: VariablePool, is_draft: bool) -> None: key = f"{memory.node_id}:{memory.memory_id}" if memory.node_id else memory.memory_id variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value) - stmt = select(ChatflowMemoryVariable).where( - and_( - ChatflowMemoryVariable.memory_id == memory.memory_id, - ChatflowMemoryVariable.tenant_id == tenant_id - ) - ) - with db.session() as session: - existing = session.execute(stmt).first() - if existing: - # Update existing - for key, value in memory.model_dump(exclude_unset=True).items(): - if hasattr(existing[0], key): - setattr(existing[0], key, value) - else: - # Create new - new_memory = ChatflowMemoryVariable( - tenant_id=tenant_id, - **memory.model_dump(exclude={'id'}) - ) - session.add(new_memory) + session.merge(ChatflowMemoryService._to_chatflow_memory_variable(memory)) session.commit() - # In draft mode, also write to workflow_draft_variables if is_draft: - from models.workflow import WorkflowDraftVariable - from services.workflow_draft_variable_service import WorkflowDraftVariableService with Session(bind=db.engine) as session: draft_var_service = WorkflowDraftVariableService(session) - - # Try to get existing variables existing_vars = draft_var_service.get_draft_variables_by_selectors( app_id=memory.app_id, selectors=[['memory_block', memory.memory_id]] ) - if existing_vars: - # Update existing draft variable draft_var = existing_vars[0] draft_var.value = memory.value else: - # Create new draft variable draft_var = WorkflowDraftVariable.new_memory_block_variable( app_id=memory.app_id, memory_id=memory.memory_id, name=memory.name, value=memory.value, - description=f"Memory block: {memory.name}" + description="" ) session.add(draft_var) - session.commit() @staticmethod @@ -164,104 +104,66 @@ class ChatflowMemoryService: tenant_id: str, app_id: str, conversation_id: Optional[str] = None, node_id: Optional[str] = None, - is_draft: bool = False) -> list[MemoryBlock]: - """Get runtime memory values based on MemoryBlockSpecs with draft mode support""" - from models.enums import DraftVariableType + is_draft: bool = False) -> Sequence[MemoryBlock]: + return [ChatflowMemoryService.get_memory_by_spec( + spec, tenant_id, app_id, conversation_id, node_id, is_draft + ) for spec in memory_block_specs] - if not memory_block_specs: - return [] - - # In draft mode, prefer reading from workflow_draft_variables - if is_draft: - # Try reading from the draft variables table - from services.workflow_draft_variable_service import WorkflowDraftVariableService - with Session(bind=db.engine) as session: + @staticmethod + def get_memory_by_spec(spec: MemoryBlockSpec, + tenant_id: str, app_id: str, + conversation_id: Optional[str] = None, + node_id: Optional[str] = None, + is_draft: bool = False) -> MemoryBlock: + with (Session(bind=db.engine) as session): + if is_draft: draft_var_service = WorkflowDraftVariableService(session) - - # Build selector list - selectors = [['memory_block', spec.id] for spec in memory_block_specs] - - # Fetch draft variables + selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"]\ + if node_id else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id] draft_vars = draft_var_service.get_draft_variables_by_selectors( app_id=app_id, - selectors=selectors + selectors=[selector] ) - - # If draft variables exist, prefer using them if draft_vars: - spec_by_id = {spec.id: spec for spec in memory_block_specs} - draft_memories = [] - - for draft_var in draft_vars: - if draft_var.node_id == DraftVariableType.MEMORY_BLOCK: - spec = spec_by_id.get(draft_var.name) - if spec: - memory_block = MemoryBlock( - id=draft_var.id, - memory_id=draft_var.name, - name=spec.name, - value=draft_var.value, - scope=spec.scope, - term=spec.term, - app_id=app_id, - conversation_id='draft', - node_id=node_id - ) - draft_memories.append(memory_block) - - if draft_memories: - return draft_memories - - memory_ids = [spec.id for spec in memory_block_specs] - - stmt = select(ChatflowMemoryVariable).where( - and_( - ChatflowMemoryVariable.memory_id.in_(memory_ids), - ChatflowMemoryVariable.tenant_id == tenant_id, - ChatflowMemoryVariable.app_id == app_id - ) - ) - - if conversation_id: - stmt = stmt.where(ChatflowMemoryVariable.conversation_id == conversation_id) - if node_id: - stmt = stmt.where(ChatflowMemoryVariable.node_id == node_id) - - with db.session() as session: - results = session.execute(stmt).all() - found_memories = {row[0].memory_id: MemoryBlock.model_validate(row[0].__dict__) for row in results} - - # Create MemoryBlock objects for specs that don't have runtime values yet - all_memories = [] - for spec in memory_block_specs: - if spec.id in found_memories: - all_memories.append(found_memories[spec.id]) - else: - # Create default memory with template value following design rules - default_memory = MemoryBlock( - id="", # Will be assigned when saved - memory_id=spec.id, + draft_var = draft_vars[0] + return MemoryBlock( + id=draft_var.id, + memory_id=draft_var.name, name=spec.name, - value=spec.template, + value=draft_var.value, scope=spec.scope, term=spec.term, - # Design rules: - # - app_id=None for global (future), app_id=str for app-specific - app_id=app_id, # Always app-specific for now - # - conversation_id=None for persistent, conversation_id=str for session - conversation_id=conversation_id if spec.term == MemoryTerm.SESSION else None, - # - node_id=None for app-scope, node_id=str for node-scope - node_id=node_id if spec.scope == MemoryScope.NODE else None + app_id=app_id, + conversation_id=conversation_id, + node_id=node_id ) - all_memories.append(default_memory) - - return all_memories + stmt = select(ChatflowMemoryVariable).where( + and_( + ChatflowMemoryVariable.memory_id == spec.id, + ChatflowMemoryVariable.tenant_id == tenant_id, + ChatflowMemoryVariable.app_id == app_id, + ChatflowMemoryVariable.node_id == node_id, + ChatflowMemoryVariable.conversation_id == conversation_id + ) + ) + result = session.execute(stmt).scalar() + if result: + return ChatflowMemoryService._to_memory_block(result) + return MemoryBlock( + id="", # Will be assigned when saved + memory_id=spec.id, + name=spec.name, + value=spec.template, + scope=spec.scope, + term=spec.term, + app_id=app_id, + conversation_id=conversation_id, + node_id=node_id + ) @staticmethod def get_app_memories_by_workflow(workflow, tenant_id: str, - conversation_id: Optional[str] = None) -> list[MemoryBlock]: - """Get app-scoped memories based on workflow configuration""" - from core.memory.entities import MemoryScope + conversation_id: Optional[str] = None) -> Sequence[MemoryBlock]: app_memory_specs = [spec for spec in workflow.memory_blocks if spec.scope == MemoryScope.APP] return ChatflowMemoryService.get_memories_by_specs( @@ -272,7 +174,7 @@ class ChatflowMemoryService: ) @staticmethod - def get_node_memories_by_workflow(workflow, node_id: str, tenant_id: str) -> list[MemoryBlock]: + def get_node_memories_by_workflow(workflow, node_id: str, tenant_id: str) -> Sequence[MemoryBlock]: """Get node-scoped memories based on workflow configuration""" from core.memory.entities import MemoryScope @@ -287,72 +189,22 @@ class ChatflowMemoryService: node_id=node_id ) - # Core Memory Orchestration features - @staticmethod - def update_memory_if_needed(tenant_id: str, app_id: str, - memory_block_spec: MemoryBlockSpec, - conversation_id: str, - variable_pool: VariablePool, - is_draft: bool = False) -> bool: - """Update app-level memory if conditions are met - - Args: - tenant_id: Tenant ID - app_id: Application ID - memory_block_spec: Memory block specification - conversation_id: Conversation ID - variable_pool: Variable pool for context - is_draft: Whether in draft mode - """ - if not ChatflowMemoryService._should_update_memory( - tenant_id, app_id, memory_block_spec, conversation_id - ): - return False - - if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC: - # Sync mode: will be processed in batch after the App run completes - # This only marks the need; actual update happens in _update_app_memory_after_run - return True - else: - # Async mode: submit asynchronous update immediately - ChatflowMemoryService._submit_async_memory_update( - tenant_id, app_id, memory_block_spec, conversation_id, variable_pool, is_draft - ) - return True - - @staticmethod - def update_node_memory_if_needed(tenant_id: str, app_id: str, - memory_block_spec: MemoryBlockSpec, - node_id: str, llm_output: str, - variable_pool: VariablePool, - is_draft: bool = False) -> bool: - """Update node-level memory after LLM execution - - Args: - tenant_id: Tenant ID - app_id: Application ID - memory_block_spec: Memory block specification - node_id: Node ID - llm_output: LLM output content - variable_pool: Variable pool for context - is_draft: Whether in draft mode - """ + def update_node_memory_if_needed( + tenant_id: str, + app_id: str, + node_id: str, + conversation_id: str, + memory_block_spec: MemoryBlockSpec, + variable_pool: VariablePool, + is_draft: bool + ) -> bool: + """Update node-level memory after LLM execution""" conversation_id_segment = variable_pool.get(('sys', 'conversation_id')) if not conversation_id_segment: return False conversation_id = conversation_id_segment.value - # Save LLM output to node conversation history - assistant_message = AssistantPromptMessage(content=llm_output) - ChatflowHistoryService.save_node_message( - prompt_message=assistant_message, - node_id=node_id, - conversation_id=str(conversation_id), - app_id=app_id, - tenant_id=tenant_id - ) - if not ChatflowMemoryService._should_update_memory( tenant_id, app_id, memory_block_spec, str(conversation_id), node_id ): @@ -372,6 +224,57 @@ class ChatflowMemoryService: ) return True + @staticmethod + def _get_memory_from_chatflow_table(memory_id: str, tenant_id: str, + app_id: Optional[str] = None, + conversation_id: Optional[str] = None, + node_id: Optional[str] = None) -> Optional[MemoryBlock]: + stmt = select(ChatflowMemoryVariable).where( + and_( + ChatflowMemoryVariable.app_id == app_id, + ChatflowMemoryVariable.memory_id == memory_id, + ChatflowMemoryVariable.tenant_id == tenant_id, + ChatflowMemoryVariable.conversation_id == conversation_id, + ChatflowMemoryVariable.node_id == node_id + ) + ) + + with db.session() as session: + result = session.execute(stmt).first() + return ChatflowMemoryService._to_memory_block(result[0]) if result else None + + @staticmethod + def _to_memory_block(entity: ChatflowMemoryVariable) -> MemoryBlock: + scope = MemoryScope(entity.scope) if not isinstance(entity.scope, MemoryScope) else entity.scope + term = MemoryTerm(entity.term) if not isinstance(entity.term, MemoryTerm) else entity.term + return MemoryBlock( + id=entity.id, + memory_id=entity.memory_id, + name=entity.name, + value=entity.value, + scope=scope, + term=term, + app_id=cast(str, entity.app_id), # It's supposed to be not nullable for now + conversation_id=entity.conversation_id, + node_id=entity.node_id, + created_at=entity.created_at, + updated_at=entity.updated_at, + ) + + @staticmethod + def _to_chatflow_memory_variable(memory_block: MemoryBlock) -> ChatflowMemoryVariable: + return ChatflowMemoryVariable( + id=memory_block.id, + node_id=memory_block.node_id, + memory_id=memory_block.memory_id, + name=memory_block.name, + value=memory_block.value, + scope=memory_block.scope, + term=memory_block.term, + app_id=memory_block.app_id, + conversation_id=memory_block.conversation_id, + ) + @staticmethod def _with_visibility( app: App, @@ -400,8 +303,7 @@ class ChatflowMemoryService: memory_block_spec: MemoryBlockSpec, conversation_id: str, node_id: Optional[str] = None) -> bool: """Check if memory should be updated based on strategy""" - if memory_block_spec.strategy != MemoryStrategy.ON_TURNS: - return False + # Currently, `memory_block_spec.strategy != MemoryStrategy.ON_TURNS` is not possible, but possible in the future # Check turn count turn_key = f"memory_turn_count:{tenant_id}:{app_id}:{conversation_id}" @@ -428,7 +330,7 @@ class ChatflowMemoryService: # Execute update asynchronously using thread thread = threading.Thread( - target=ChatflowMemoryService._update_single_memory, + target=ChatflowMemoryService._update_app_single_memory, kwargs={ 'tenant_id': tenant_id, 'app_id': app_id, @@ -492,28 +394,18 @@ class ChatflowMemoryService: tenant_id: str, app_id: str, node_id: str, llm_output: str, variable_pool: VariablePool, is_draft: bool = False): - """Execute node memory update""" - try: - # Call existing _perform_memory_update method here - ChatflowMemoryService._perform_memory_update( - tenant_id=tenant_id, - app_id=app_id, - memory_block_spec=memory_block_spec, - conversation_id=str(variable_pool.get(('sys', 'conversation_id'))), - variable_pool=variable_pool, - node_id=node_id, - is_draft=is_draft - ) - except Exception as e: - logger.exception( - "Failed to update node memory %s for node %s", - memory_block_spec.id, - node_id, - exc_info=e - ) + ChatflowMemoryService._perform_memory_update( + tenant_id=tenant_id, + app_id=app_id, + memory_block_spec=memory_block_spec, + conversation_id=str(variable_pool.get(('sys', 'conversation_id'))), + variable_pool=variable_pool, + node_id=node_id, + is_draft=is_draft + ) @staticmethod - def _update_single_memory(*, tenant_id: str, app_id: str, + def _update_app_single_memory(*, tenant_id: str, app_id: str, memory_block_spec: MemoryBlockSpec, conversation_id: str, variable_pool: VariablePool, @@ -535,62 +427,26 @@ class ChatflowMemoryService: conversation_id: str, variable_pool: VariablePool, node_id: Optional[str] = None, is_draft: bool = False): - """Perform the actual memory update using LLM - - Args: - tenant_id: Tenant ID - app_id: Application ID - memory_block_spec: Memory block specification - conversation_id: Conversation ID - variable_pool: Variable pool for context - node_id: Optional node ID for node-level memory updates - is_draft: Whether in draft mode - """ - # Get conversation history + """Perform the actual memory update using LLM""" history = ChatflowHistoryService.get_visible_chat_history( conversation_id=conversation_id, app_id=app_id, tenant_id=tenant_id, - node_id=node_id, # Pass node_id, if None then get app-level history - max_visible_count=memory_block_spec.preserved_turns + node_id=node_id, ) # Get current memory value - current_memory = ChatflowMemoryService.get_memory( + current_memory = ChatflowMemoryService._get_memory_from_chatflow_table( memory_id=memory_block_spec.id, tenant_id=tenant_id, app_id=app_id, - conversation_id=conversation_id if memory_block_spec.term == MemoryTerm.SESSION else None, + conversation_id=conversation_id, node_id=node_id ) current_value = current_memory.value if current_memory else memory_block_spec.template - # Build update prompt - adjust wording based on whether there's a node_id - context_type = "Node conversation history" if node_id else "Conversation history" - memory_update_prompt = f""" - Based on the following {context_type}, update the memory content: - Current memory: {current_value} - - {context_type}: - {[msg.content for msg in history]} - - Update instruction: {memory_block_spec.instruction} - - Please output the updated memory content: - """ - - # Invoke LLM to update memory - extracted as a separate method - updated_value = ChatflowMemoryService._invoke_llm_for_memory_update( - tenant_id, - memory_block_spec, - memory_update_prompt, - current_value - ) - - if updated_value is None: - return # LLM invocation failed # Save updated memory updated_memory = MemoryBlock( @@ -720,23 +576,10 @@ class ChatflowMemoryService: @staticmethod def update_app_memory_after_run(workflow, conversation_id: str, variable_pool: VariablePool, is_draft: bool = False): - """Update app-level memory after run completion - - Args: - workflow: Workflow object - conversation_id: Conversation ID - variable_pool: Variable pool - is_draft: Whether in draft mode - """ - from core.memory.entities import MemoryScope - - memory_blocks = workflow.memory_blocks - - # Separate sync and async memory blocks + """Update app-level memory after run completion""" sync_blocks = [] async_blocks = [] - - for block in memory_blocks: + for block in workflow.memory_blocks: if block.scope == MemoryScope.APP: if block.update_mode == "sync": sync_blocks.append(block) @@ -805,7 +648,7 @@ class ChatflowMemoryService: futures = [] for block in sync_blocks: future = executor.submit( - ChatflowMemoryService._update_single_memory, + ChatflowMemoryService._update_app_single_memory, tenant_id=workflow.tenant_id, app_id=workflow.app_id, memory_block_spec=block,