diff --git a/api/core/memory/README.md b/api/core/memory/README.md
new file mode 100644
index 0000000000..ba8f743125
--- /dev/null
+++ b/api/core/memory/README.md
@@ -0,0 +1,434 @@
+# Memory Module
+
+This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
+
+## Overview
+
+The memory module contains two types of memory implementations:
+
+1. **TokenBufferMemory** - Conversation-level memory (existing)
+2. **NodeTokenBufferMemory** - Node-level memory (new, **Chatflow only**)
+
+> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
+> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
+> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
+
+```
+┌──────────────────────────────────────────────────────────────┐
+│                     Memory Architecture                      │
+├──────────────────────────────────────────────────────────────┤
+│                                                              │
+│  ┌────────────────────────────────────────────────────────┐  │
+│  │ TokenBufferMemory                                      │  │
+│  │ Scope:   Conversation                                  │  │
+│  │ Storage: Database (Message table)                      │  │
+│  │ Key:     conversation_id                               │  │
+│  └────────────────────────────────────────────────────────┘  │
+│                                                              │
+│  ┌────────────────────────────────────────────────────────┐  │
+│  │ NodeTokenBufferMemory                                  │  │
+│  │ Scope:   Node within Conversation                      │  │
+│  │ Storage: Object Storage (JSON file)                    │  │
+│  │ Key:     (app_id, conversation_id, node_id)            │  │
+│  └────────────────────────────────────────────────────────┘  │
+│                                                              │
+└──────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## TokenBufferMemory (Existing)
+
+### Purpose
+
+`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
+
+### Key Features
+
+- **Conversation-scoped**: All messages within a conversation are candidates
+- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
+- **Token-limited**: Truncates history to fit within `max_token_limit`
+- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
+
+### Data Flow
+
+```
+Message Table               TokenBufferMemory                  LLM
+     │                               │                           │
+     │  SELECT * FROM messages       │                           │
+     │  WHERE conversation_id = ?    │                           │
+     │  ORDER BY created_at DESC     │                           │
+     ├──────────────────────────────▶│                           │
+     │                               │                           │
+     │      extract_thread_messages()                            │
+     │                               │                           │
+     │      build_prompt_message_with_files()                    │
+     │                               │                           │
+     │      truncate by max_token_limit                          │
+     │                               │                           │
+     │                               │  Sequence[PromptMessage]  │
+     │                               ├──────────────────────────▶│
+     │                               │                           │
+```
+
+### Thread Extraction
+
+When a user regenerates a response, a new thread is created:
+
+```
+Message A (user)
+ ├── Message A' (assistant)
+ │    └── Message B (user)
+ │         └── Message B' (assistant)
+ └── Message A'' (assistant, regenerated)   ← New thread
+      └── Message C (user)
+           └── Message C' (assistant)
+```
+
+`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
+
+### Usage
+
+```python
+from core.memory.token_buffer_memory import TokenBufferMemory
+
+memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
+```
+
+---
+
+## NodeTokenBufferMemory (New)
+
+### Purpose
+
+`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
+
+### Use Cases
+
+1. **Multi-LLM Workflows**: Different LLM nodes need separate context
+2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
+3. **Specialized Agents**: Each agent node maintains its own dialogue history
+
+### Design Decisions
+
+#### Storage: Object Storage for Messages (No New Database Table)
+
+| Aspect             | Database             | Object Storage     |
+| ------------------ | -------------------- | ------------------ |
+| Cost               | High                 | Low                |
+| Query Flexibility  | High                 | Low                |
+| Schema Changes     | Migration required   | None               |
+| Existing Precedent | ConversationVariable | File uploads, logs |
+
+**Decision**: Store message data in object storage, but still use existing database tables for file metadata.
+
+**What is stored in Object Storage:**
+
+- Dialogue turns: user and assistant text content, one turn per workflow execution
+- Turn keys: the workflow_run_id of the execution that produced the turn
+- File references (upload_file_id, tool_file_id, url)
+
+**What still requires Database queries:**
+
+- Thread resolution: the current thread is derived from the `Message` table's
+  parent_message_id chain, then mapped onto stored turns via workflow_run_id
+- File reconstruction: when reading node memory, file references are used to query
+  `UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
+  complete `File` objects with storage_key, mime_type, etc.
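+
+For example, a stored `local_file` reference is rebuilt roughly like this (a minimal
+sketch; the `stored_ref` values and the `tenant_id` variable are illustrative):
+
+```python
+from factories import file_factory
+
+# A file reference as persisted in node memory (see Data Structure below)
+stored_ref = {
+    "type": "image",
+    "transfer_method": "local_file",
+    "upload_file_id": "file-uuid-123",
+}
+
+# build_from_mapping() looks up the UploadFile / ToolFile rows to restore
+# storage_key, mime_type, etc., returning a complete File object
+file = file_factory.build_from_mapping(mapping=stored_ref, tenant_id=tenant_id)
+```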
+
+**Why this hybrid approach:**
+
+- No database migration required (no new tables)
+- Message data may be large; object storage is cost-effective
+- File metadata is already in the database, so there is no need to duplicate it
+- Aligns with existing storage patterns (file uploads, logs)
+
+#### Storage Key Format
+
+```
+node_memory/{app_id}/{conversation_id}/{node_id}.json
+```
+
+#### Data Structure
+
+Each workflow execution stores one dialogue turn, keyed by its `workflow_run_id`:
+
+```json
+{
+  "version": 1,
+  "turns": {
+    "wf-run-001": {
+      "user_content": "Analyze this image",
+      "user_files": [
+        {
+          "type": "image",
+          "transfer_method": "local_file",
+          "upload_file_id": "file-uuid-123"
+        }
+      ],
+      "assistant_content": "This is a landscape image...",
+      "assistant_files": []
+    }
+  }
+}
+```
+
+### Thread Support
+
+Node memory supports regeneration scenarios without storing its own thread links.
+Thread tracking is delegated to the `Message` table: on read, the current thread is
+resolved via `parent_message_id` (reusing `extract_thread_messages()`), and the
+thread's `workflow_run_id`s select the matching stored turns:
+
+```python
+def _get_thread_workflow_run_ids(self) -> list[str]:
+    """
+    Get workflow_run_ids for the current thread by querying the Message table.
+    Returns them in chronological order (oldest first).
+    """
+    ...
+```
+
+### File Handling
+
+Files are stored as references (not full metadata):
+
+```python
+class NodeMemoryFile(BaseModel):
+    type: str                   # image, audio, video, document, custom
+    transfer_method: str        # local_file, remote_url, tool_file
+    upload_file_id: str | None  # for local_file
+    tool_file_id: str | None    # for tool_file
+    url: str | None             # for remote_url
+```
+
+When reading, files are rebuilt using `file_factory.build_from_mapping()`.
+
+### API Design
+
+```python
+class NodeTokenBufferMemory(BaseMemory):
+    def __init__(
+        self,
+        app_id: str,
+        conversation_id: str,
+        node_id: str,
+        tenant_id: str,
+        model_instance: ModelInstance,
+    ):
+        """
+        Initialize node-level memory.
+
+        :param app_id: Application ID
+        :param conversation_id: Conversation ID
+        :param node_id: Node ID in the workflow
+        :param tenant_id: Tenant ID (for file reconstruction)
+        :param model_instance: Model instance for token counting
+        """
+        ...
+
+    def add_messages(
+        self,
+        workflow_run_id: str,
+        user_content: str,
+        user_files: Sequence[File] | None = None,
+        assistant_content: str = "",
+        assistant_files: Sequence[File] | None = None,
+    ) -> None:
+        """
+        Append a dialogue turn (user + assistant) to node memory.
+        Call this after LLM node execution completes.
+
+        :param workflow_run_id: Current workflow execution ID (the storage key for this turn)
+        :param user_content: User's text input
+        :param user_files: Files attached by the user
+        :param assistant_content: Assistant's text response
+        :param assistant_files: Files generated by the assistant
+        """
+        ...
+
+    def get_history_prompt_messages(
+        self,
+        *,
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> Sequence[PromptMessage]:
+        """
+        Retrieve history as a PromptMessage sequence.
+
+        The current thread is resolved from the Message table, matching turns are
+        rebuilt (including files), and the result is truncated to max_token_limit.
+
+        :param max_token_limit: Maximum tokens for history
+        :param message_limit: Unused; kept for interface compatibility
+        :return: Sequence of PromptMessage for LLM context
+        """
+        ...
+
+    def flush(self) -> None:
+        """
+        Persist buffered changes to object storage.
+        Call this at the end of node execution.
+        """
+        ...
+
+    def clear(self) -> None:
+        """
+        Clear all messages in this node's memory.
+        """
+        ...
+```
+
+### Data Flow
+
+```
+Object Storage              NodeTokenBufferMemory               LLM Node
+     │                               │                              │
+     │                               │◀── get_history_prompt_messages()
+     │  storage.load(key)            │                              │
+     │◀──────────────────────────────┤                              │
+     │                               │                              │
+     │  JSON data                    │                              │
+     ├──────────────────────────────▶│                              │
+     │                               │                              │
+     │      _get_thread_workflow_run_ids() via Message table        │
+     │                               │                              │
+     │      _rebuild_files() via file_factory                       │
+     │                               │                              │
+     │      _build_prompt_message() per turn                        │
+     │                               │                              │
+     │      truncate by max_token_limit                             │
+     │                               │                              │
+     │                               │  Sequence[PromptMessage]     │
+     │                               ├─────────────────────────────▶│
+     │                               │                              │
+     │                               │◀── LLM execution complete    │
+     │                               │                              │
+     │                               │◀── add_messages() + flush()  │
+     │                               │                              │
+     │  storage.save(key, data)      │                              │
+     │◀──────────────────────────────┤                              │
+     │                               │                              │
+```
+
+### Integration with LLM Node
+
+```python
+# In LLM Node execution
+
+# 1. Fetch memory; fetch_memory() dispatches on MemoryConfig.mode
+memory = llm_utils.fetch_memory(
+    variable_pool=variable_pool,
+    app_id=app_id,
+    tenant_id=tenant_id,
+    node_data_memory=node_data.memory,
+    model_instance=model_instance,
+    node_id=node_id,  # required for node mode
+)
+
+# 2. Get history for context (same interface for both memory types)
+if memory:
+    history = memory.get_history_prompt_messages(max_token_limit=max_token_limit)
+    prompt_messages = [*history, *current_messages]
+else:
+    prompt_messages = current_messages
+
+# 3. Call LLM
+response = model_instance.invoke(prompt_messages)
+
+# 4. Append to node memory (only for NodeTokenBufferMemory)
+if isinstance(memory, NodeTokenBufferMemory):
+    memory.add_messages(
+        workflow_run_id=workflow_run_id,
+        user_content=user_input,
+        user_files=user_files,
+        assistant_content=response.content,
+        assistant_files=response_files,
+    )
+    memory.flush()
+```
+
+### Configuration
+
+Add `MemoryMode` and a `mode` field to `MemoryConfig` in
+`core/prompt/entities/advanced_prompt_entities.py`:
+
+```python
+class MemoryMode(StrEnum):
+    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
+    NODE = "node"                  # Use NodeTokenBufferMemory (new, Chatflow only)
+
+class MemoryConfig(BaseModel):
+    # Existing fields
+    role_prefix: RolePrefix | None = None
+    window: WindowConfig
+    query_prompt_template: str | None = None
+
+    # Memory mode (new)
+    mode: MemoryMode = MemoryMode.CONVERSATION
+```
+
+**Mode Behavior:**
+
+| Mode           | Memory Class          | Scope                    | Availability  |
+| -------------- | --------------------- | ------------------------ | ------------- |
+| `conversation` | TokenBufferMemory     | Entire conversation      | All app modes |
+| `node`         | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
+
+> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
+> fall back to no memory or raise a configuration error.
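+
+The `fetch_memory()` change later in this diff implements the fall-back variant,
+returning `None` (no memory) whenever the requirements are not met. A condensed
+sketch of that dispatch (bodies elided; see `llm_utils.py` below for the real code):
+
+```python
+def fetch_memory(...) -> BaseMemory | None:
+    if not node_data_memory:
+        return None
+    # Both modes need conversation_id, which only Chatflow provides
+    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+    if not isinstance(conversation_id_variable, StringSegment):
+        return None  # non-Chatflow context: fall back to no memory
+    if node_data_memory.mode == MemoryMode.NODE:
+        if not node_id:
+            return None  # node mode without a node_id: fall back to no memory
+        return NodeTokenBufferMemory(...)
+    ...  # conversation mode: existing TokenBufferMemory path
+```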
+
+---
+
+## Comparison
+
+| Feature        | TokenBufferMemory        | NodeTokenBufferMemory     |
+| -------------- | ------------------------ | ------------------------- |
+| Scope          | Conversation             | Node within Conversation  |
+| Storage        | Database (Message table) | Object Storage (JSON)     |
+| Thread Support | Yes                      | Yes                       |
+| File Support   | Yes (via MessageFile)    | Yes (via file references) |
+| Token Limit    | Yes                      | Yes                       |
+| Use Case       | Standard chat apps       | Complex workflows         |
+
+---
+
+## Future Considerations
+
+1. **Cleanup Task**: Add a Celery task to clean up old node memory files
+2. **Concurrency**: Consider a Redis lock for concurrent node executions
+3. **Compression**: Compress large memory files to reduce storage costs
+4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory
diff --git a/api/core/memory/__init__.py b/api/core/memory/__init__.py
new file mode 100644
index 0000000000..4baef1a835
--- /dev/null
+++ b/api/core/memory/__init__.py
@@ -0,0 +1,15 @@
+from core.memory.base import BaseMemory
+from core.memory.node_token_buffer_memory import (
+    NodeMemoryData,
+    NodeMemoryFile,
+    NodeTokenBufferMemory,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+
+__all__ = [
+    "BaseMemory",
+    "NodeMemoryData",
+    "NodeMemoryFile",
+    "NodeTokenBufferMemory",
+    "TokenBufferMemory",
+]
diff --git a/api/core/memory/base.py b/api/core/memory/base.py
new file mode 100644
index 0000000000..af6e8eeda3
--- /dev/null
+++ b/api/core/memory/base.py
@@ -0,0 +1,83 @@
+"""
+Base memory interfaces and types.
+
+This module defines the common protocol for memory implementations.
+"""
+
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+
+from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
+
+
+class BaseMemory(ABC):
+    """
+    Abstract base class for memory implementations.
+
+    Provides a common interface for both conversation-level and node-level memory.
+    """
+
+    @abstractmethod
+    def get_history_prompt_messages(
+        self,
+        *,
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> Sequence[PromptMessage]:
+        """
+        Get history prompt messages.
+
+        :param max_token_limit: Maximum tokens for history
+        :param message_limit: Maximum number of messages
+        :return: Sequence of PromptMessage for LLM context
+        """
+        pass
+
+    def get_history_prompt_text(
+        self,
+        human_prefix: str = "Human",
+        ai_prefix: str = "Assistant",
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> str:
+        """
+        Get history prompt as formatted text.
+
+        :param human_prefix: Prefix for human messages
+        :param ai_prefix: Prefix for assistant messages
+        :param max_token_limit: Maximum tokens for history
+        :param message_limit: Maximum number of messages
+        :return: Formatted history text
+        """
+        from core.model_runtime.entities import (
+            PromptMessageRole,
+            TextPromptMessageContent,
+        )
+
+        prompt_messages = self.get_history_prompt_messages(
+            max_token_limit=max_token_limit,
+            message_limit=message_limit,
+        )
+
+        string_messages = []
+        for m in prompt_messages:
+            if m.role == PromptMessageRole.USER:
+                role = human_prefix
+            elif m.role == PromptMessageRole.ASSISTANT:
+                role = ai_prefix
+            else:
+                continue
+
+            if isinstance(m.content, list):
+                inner_msg = ""
+                for content in m.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        inner_msg += f"{content.data}\n"
+                    elif isinstance(content, ImagePromptMessageContent):
+                        inner_msg += "[image]\n"
+                string_messages.append(f"{role}: {inner_msg.strip()}")
+            else:
+                message = f"{role}: {m.content}"
+                string_messages.append(message)
+
+        return "\n".join(string_messages)
diff --git a/api/core/memory/node_token_buffer_memory.py b/api/core/memory/node_token_buffer_memory.py
new file mode 100644
index 0000000000..bc38c953eb
--- /dev/null
+++ b/api/core/memory/node_token_buffer_memory.py
@@ -0,0 +1,353 @@
+"""
+Node-level Token Buffer Memory for Chatflow.
+
+This module provides node-scoped memory within a conversation.
+Each LLM node in a workflow can maintain its own independent conversation history.
+
+Note: This is only available in Chatflow (advanced-chat mode) because it requires
+both conversation_id and node_id.
+
+Design:
+- Storage is indexed by workflow_run_id (each execution stores one turn)
+- Thread tracking leverages the Message table's parent_message_id structure
+- On read: query the Message table for the current thread, then filter node memory by workflow_run_ids
+"""
+
+import logging
+from collections.abc import Sequence
+
+from pydantic import BaseModel
+from sqlalchemy import select
+
+from core.file import File, FileTransferMethod
+from core.memory.base import BaseMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import Message
+
+logger = logging.getLogger(__name__)
+
+
+class NodeMemoryFile(BaseModel):
+    """File reference stored in node memory."""
+
+    type: str  # image, audio, video, document, custom
+    transfer_method: str  # local_file, remote_url, tool_file
+    upload_file_id: str | None = None
+    tool_file_id: str | None = None
+    url: str | None = None
+
+
+class NodeMemoryTurn(BaseModel):
+    """A single dialogue turn (user + assistant) in node memory."""
+
+    user_content: str = ""
+    user_files: list[NodeMemoryFile] = []
+    assistant_content: str = ""
+    assistant_files: list[NodeMemoryFile] = []
+
+
+class NodeMemoryData(BaseModel):
+    """Root data structure for node memory storage."""
+
+    version: int = 1
+    # Key: workflow_run_id, Value: dialogue turn
+    turns: dict[str, NodeMemoryTurn] = {}
+
+
+class NodeTokenBufferMemory(BaseMemory):
+    """
+    Node-level Token Buffer Memory.
+
+    Provides node-scoped memory within a conversation. Each LLM node can maintain
+    its own independent conversation history, stored in object storage.
+
+    Key design: Thread tracking is delegated to the Message table's parent_message_id.
+    Storage is indexed by workflow_run_id for easy filtering.
+
+    Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
+    """
+
+    def __init__(
+        self,
+        app_id: str,
+        conversation_id: str,
+        node_id: str,
+        tenant_id: str,
+        model_instance: ModelInstance,
+    ):
+        """
+        Initialize node-level memory.
+
+        :param app_id: Application ID
+        :param conversation_id: Conversation ID
+        :param node_id: Node ID in the workflow
+        :param tenant_id: Tenant ID for file reconstruction
+        :param model_instance: Model instance for token counting
+        """
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.node_id = node_id
+        self.tenant_id = tenant_id
+        self.model_instance = model_instance
+        self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
+        self._data: NodeMemoryData | None = None
+        self._dirty = False
+
+    def _load(self) -> NodeMemoryData:
+        """Load data from object storage."""
+        if self._data is not None:
+            return self._data
+
+        try:
+            raw = storage.load_once(self._storage_key)
+            self._data = NodeMemoryData.model_validate_json(raw)
+        except Exception:
+            # File not found or parse error; start fresh
+            self._data = NodeMemoryData()
+
+        return self._data
+
+    def _save(self) -> None:
+        """Save data to object storage."""
+        if self._data is not None:
+            storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
+            self._dirty = False
+
+    def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
+        """Convert a File object to a NodeMemoryFile reference."""
+        return NodeMemoryFile(
+            type=file.type.value if hasattr(file.type, "value") else str(file.type),
+            transfer_method=(
+                file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
+            ),
+            upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
+            tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
+            url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
+        )
+
+    def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
+        """Convert a NodeMemoryFile to a mapping for file_factory."""
+        mapping: dict = {
+            "type": memory_file.type,
+            "transfer_method": memory_file.transfer_method,
+        }
+        if memory_file.upload_file_id:
+            mapping["upload_file_id"] = memory_file.upload_file_id
+        if memory_file.tool_file_id:
+            mapping["tool_file_id"] = memory_file.tool_file_id
+        if memory_file.url:
+            mapping["url"] = memory_file.url
+        return mapping
+
+    def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
+        """Rebuild File objects from NodeMemoryFile references."""
+        if not memory_files:
+            return []
+
+        from factories import file_factory
+
+        files = []
+        for mf in memory_files:
+            try:
+                mapping = self._memory_file_to_mapping(mf)
+                file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
+                files.append(file)
+            except Exception as e:
+                logger.warning("Failed to rebuild file from memory: %s", e)
+                continue
+        return files
+
+    def _build_prompt_message(
+        self,
+        role: str,
+        content: str,
+        files: list[File],
+        detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
+    ) -> PromptMessage:
+        """Build a PromptMessage from content and files."""
+        from core.file import file_manager
+
+        if not files:
+            if role == "user":
+                return UserPromptMessage(content=content)
+            else:
+                return AssistantPromptMessage(content=content)
+
+        # Build multimodal content
+        prompt_contents: list = []
+        for file in files:
+            try:
+                prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
+                prompt_contents.append(prompt_content)
+            except Exception as e:
+                logger.warning("Failed to convert file to prompt content: %s", e)
+                continue
+
+        prompt_contents.append(TextPromptMessageContent(data=content))
+
+        if role == "user":
+            return UserPromptMessage(content=prompt_contents)
+        else:
+            return AssistantPromptMessage(content=prompt_contents)
+
+    def _get_thread_workflow_run_ids(self) -> list[str]:
+        """
+        Get workflow_run_ids for the current thread by querying the Message table.
+
+        Returns workflow_run_ids in chronological order (oldest first).
+        """
+        # Query messages for this conversation
+        stmt = (
+            select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
+        )
+        messages = db.session.scalars(stmt.limit(500)).all()
+
+        if not messages:
+            return []
+
+        # Extract thread messages using existing logic
+        thread_messages = extract_thread_messages(messages)
+
+        # A newly created message has a temporarily empty answer; skip it
+        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
+            thread_messages.pop(0)
+
+        # Reverse to get chronological order, then extract workflow_run_ids
+        workflow_run_ids = []
+        for msg in reversed(thread_messages):
+            if msg.workflow_run_id:
+                workflow_run_ids.append(msg.workflow_run_id)
+
+        return workflow_run_ids
+
+    def add_messages(
+        self,
+        workflow_run_id: str,
+        user_content: str,
+        user_files: Sequence[File] | None = None,
+        assistant_content: str = "",
+        assistant_files: Sequence[File] | None = None,
+    ) -> None:
+        """
+        Add a dialogue turn to node memory.
+        Call this after LLM node execution completes.
+
+        :param workflow_run_id: Current workflow execution ID
+        :param user_content: User's text input
+        :param user_files: Files attached by user
+        :param assistant_content: Assistant's text response
+        :param assistant_files: Files generated by assistant
+        """
+        data = self._load()
+
+        # Convert files to memory file references
+        user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
+        assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
+
+        # Store the turn indexed by workflow_run_id
+        data.turns[workflow_run_id] = NodeMemoryTurn(
+            user_content=user_content,
+            user_files=user_memory_files,
+            assistant_content=assistant_content,
+            assistant_files=assistant_memory_files,
+        )
+
+        self._dirty = True
+
+    def get_history_prompt_messages(
+        self,
+        *,
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> Sequence[PromptMessage]:
+        """
+        Retrieve history as a PromptMessage sequence.
+
+        Thread tracking is handled by querying the Message table's parent_message_id structure.
+
+        :param max_token_limit: Maximum tokens for history
+        :param message_limit: Unused; kept for interface compatibility
+        :return: Sequence of PromptMessage for LLM context
+        """
+        # message_limit is unused in NodeTokenBufferMemory (the token limit is used instead)
+        _ = message_limit
+        detail = ImagePromptMessageContent.DETAIL.HIGH
+        data = self._load()
+
+        if not data.turns:
+            return []
+
+        # Get workflow_run_ids for the current thread from the Message table
+        thread_workflow_run_ids = self._get_thread_workflow_run_ids()
+
+        if not thread_workflow_run_ids:
+            return []
+
+        # Build prompt messages in thread order
+        prompt_messages: list[PromptMessage] = []
+        for wf_run_id in thread_workflow_run_ids:
+            turn = data.turns.get(wf_run_id)
+            if not turn:
+                # This workflow execution didn't store node memory
+                continue
+
+            # Build the user message
+            user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
+            user_msg = self._build_prompt_message(
+                role="user",
+                content=turn.user_content,
+                files=user_files,
+                detail=detail,
+            )
+            prompt_messages.append(user_msg)
+
+            # Build the assistant message
+            assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
+            assistant_msg = self._build_prompt_message(
+                role="assistant",
+                content=turn.assistant_content,
+                files=assistant_files,
+                detail=detail,
+            )
+            prompt_messages.append(assistant_msg)
+
+        if not prompt_messages:
+            return []
+
+        # Truncate by token limit
+        try:
+            current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+            while current_tokens > max_token_limit and len(prompt_messages) > 1:
+                prompt_messages.pop(0)
+                current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+        except Exception as e:
+            logger.warning("Failed to count tokens for truncation: %s", e)
+
+        return prompt_messages
+
+    def flush(self) -> None:
+        """
+        Persist buffered changes to object storage.
+        Call this at the end of node execution.
+        """
+        if self._dirty:
+            self._save()
+
+    def clear(self) -> None:
+        """Clear all messages in this node's memory."""
+        self._data = NodeMemoryData()
+        self._save()
+
+    def exists(self) -> bool:
+        """Check if node memory exists in storage."""
+        return storage.exists(self._storage_key)
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 3ebbb60f85..58ffe04240 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.factory import DifyAPIRepositoryFactory
 
 
-class TokenBufferMemory:
+class TokenBufferMemory(BaseMemory):
     def __init__(
         self,
         conversation: Conversation,
@@ -115,10 +115,14 @@ class TokenBufferMemory:
         return AssistantPromptMessage(content=prompt_message_contents)
 
     def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: int | None = None
+        self,
+        *,
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
     ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
+
         :param max_token_limit: max token limit
         :param message_limit: message limit
         """
@@ -200,44 +204,3 @@ class TokenBufferMemory:
             curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
 
         return prompt_messages
-
-    def get_history_prompt_text(
-        self,
-        human_prefix: str = "Human",
-        ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
-        message_limit: int | None = None,
-    ) -> str:
-        """
-        Get history prompt text.
-        :param human_prefix: human prefix
-        :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
-        :param message_limit: message limit
-        :return:
-        """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
-
-        string_messages = []
-        for m in prompt_messages:
-            if m.role == PromptMessageRole.USER:
-                role = human_prefix
-            elif m.role == PromptMessageRole.ASSISTANT:
-                role = ai_prefix
-            else:
-                continue
-
-            if isinstance(m.content, list):
-                inner_msg = ""
-                for content in m.content:
-                    if isinstance(content, TextPromptMessageContent):
-                        inner_msg += f"{content.data}\n"
-                    elif isinstance(content, ImagePromptMessageContent):
-                        inner_msg += "[image]\n"
-
-                string_messages.append(f"{role}: {inner_msg.strip()}")
-            else:
-                message = f"{role}: {m.content}"
-                string_messages.append(message)
-
-        return "\n".join(string_messages)
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index 7094633093..457800bad2 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -1,3 +1,4 @@
+from enum import StrEnum
 from typing import Literal
 
 from pydantic import BaseModel
@@ -5,6 +6,13 @@ from pydantic import BaseModel
 from core.model_runtime.entities.message_entities import PromptMessageRole
 
 
+class MemoryMode(StrEnum):
+    """Memory mode for LLM nodes."""
+
+    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
+    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)
+
+
 class ChatModelMessage(BaseModel):
     """
     Chat Message.
@@ -48,3 +56,4 @@
     role_prefix: RolePrefix | None = None
     window: WindowConfig
     query_prompt_template: str | None = None
+    mode: MemoryMode = MemoryMode.CONVERSATION
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index 0c545469bc..aa5c784357 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -8,12 +8,13 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_entities import QuotaUnit
 from core.file.models import File
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory import NodeTokenBufferMemory, TokenBufferMemory
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
 from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.llm.entities import ModelConfig
@@ -86,25 +87,56 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
 
 
 def fetch_memory(
-    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
-) -> TokenBufferMemory | None:
+    variable_pool: VariablePool,
+    app_id: str,
+    tenant_id: str,
+    node_data_memory: MemoryConfig | None,
+    model_instance: ModelInstance,
+    node_id: str = "",
+) -> BaseMemory | None:
+    """
+    Fetch memory based on configuration mode.
+
+    Returns TokenBufferMemory for conversation mode (default),
+    or NodeTokenBufferMemory for node mode (Chatflow only).
+
+    :param variable_pool: Variable pool containing system variables
+    :param app_id: Application ID
+    :param tenant_id: Tenant ID
+    :param node_data_memory: Memory configuration
+    :param model_instance: Model instance for token counting
+    :param node_id: Node ID in the workflow (required for node mode)
+    :return: Memory instance or None if not applicable
+    """
     if not node_data_memory:
         return None
 
-    # get conversation id
+    # Get conversation_id from variable pool (required for both modes in Chatflow)
     conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
     if not isinstance(conversation_id_variable, StringSegment):
         return None
     conversation_id = conversation_id_variable.value
 
-    with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-        conversation = session.scalar(stmt)
-        if not conversation:
+    # Return appropriate memory type based on mode
+    if node_data_memory.mode == MemoryMode.NODE:
+        # Node-level memory (Chatflow only)
+        if not node_id:
             return None
-
-    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-    return memory
+        return NodeTokenBufferMemory(
+            app_id=app_id,
+            conversation_id=conversation_id,
+            node_id=node_id,
+            tenant_id=tenant_id,
+            model_instance=model_instance,
+        )
+    else:
+        # Conversation-level memory (default)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+            conversation = session.scalar(stmt)
+            if not conversation:
+                return None
+            return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
 
 def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 04e2802191..bbd6c92e75 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -14,7 +14,8 @@ from core.file import File, FileTransferMethod, FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
+from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -206,8 +207,10 @@ class LLMNode(Node[LLMNodeData]):
             memory = llm_utils.fetch_memory(
                 variable_pool=variable_pool,
                 app_id=self.app_id,
+                tenant_id=self.tenant_id,
                 node_data_memory=self.node_data.memory,
                 model_instance=model_instance,
+                node_id=self._node_id,
             )
 
             query: str | None = None
@@ -299,12 +302,41 @@ class LLMNode(Node[LLMNodeData]):
                 "reasoning_content": reasoning_content,
                 "usage": jsonable_encoder(usage),
                 "finish_reason": finish_reason,
+                "context": self._build_context(prompt_messages, clean_text, model_config.mode),
             }
             if structured_output:
                 outputs["structured_output"] = structured_output.structured_output
             if self._file_outputs:
                 outputs["files"] = ArrayFileSegment(value=self._file_outputs)
 
+            # Write to Node Memory if in node memory mode
+            if isinstance(memory, NodeTokenBufferMemory):
+                # Get workflow_run_id as the key for this execution
+                workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
+                workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else ""
+
+                if workflow_run_id:
+                    # Resolve the query template to get actual user content;
+                    # query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}"
+                    actual_query = variable_pool.convert_template(query or "").text
+
+                    # Get user files from sys.files
+                    user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
+                    user_files: list[File] = []
+                    if isinstance(user_files_var, ArrayFileSegment):
+                        user_files = list(user_files_var.value)
+                    elif isinstance(user_files_var, FileSegment):
+                        user_files = [user_files_var.value]
+
+                    memory.add_messages(
+                        workflow_run_id=workflow_run_id,
+                        user_content=actual_query,
+                        user_files=user_files,
+                        assistant_content=clean_text,
+                        assistant_files=self._file_outputs,
+                    )
+                    memory.flush()
+
             # Send final chunk event to indicate streaming is complete
             yield StreamChunkEvent(
                 selector=[self._node_id, "text"],
@@ -564,6 +596,22 @@ class LLMNode(Node[LLMNodeData]):
         # Separated mode: always return clean text and reasoning_content
         return clean_text, reasoning_content or ""
 
+    @staticmethod
+    def _build_context(
+        prompt_messages: Sequence[PromptMessage],
+        assistant_response: str,
+        model_mode: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Build context from prompt messages and assistant response.
+        Excludes system messages and includes the current LLM response.
+        """
+        context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
+        context_messages.append(AssistantPromptMessage(content=assistant_response))
+        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            model_mode=model_mode, prompt_messages=context_messages
+        )
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -776,7 +824,7 @@ class LLMNode(Node[LLMNodeData]):
         sys_query: str | None = None,
         sys_files: Sequence["File"],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: BaseMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         memory_config: MemoryConfig | None = None,
@@ -1335,7 +1383,7 @@ def _calculate_rest_token(
 
 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: BaseMemory | None,
    memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> Sequence[PromptMessage]:
@@ -1352,7 +1400,7 @@ def _handle_memory_completion_mode(
 
 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: BaseMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
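
---

For reviewers, an end-to-end usage sketch of the new class (IDs are illustrative, and `model_instance` is assumed to be in scope; in the node itself this wiring happens via `llm_utils.fetch_memory` and the write-back block shown above):

```python
from core.memory import NodeTokenBufferMemory

# Illustrative IDs; in practice these come from the variable pool / node context
memory = NodeTokenBufferMemory(
    app_id="app-1",
    conversation_id="conv-1",
    node_id="llm-node-1",
    tenant_id="tenant-1",
    model_instance=model_instance,
)

# Read: resolves the current thread from the Message table, rebuilds files,
# and truncates the history to the token limit
history = memory.get_history_prompt_messages(max_token_limit=2000)

# ... invoke the model with [*history, *current_messages] ...

# Write: buffer this execution's turn, then persist it to object storage
memory.add_messages(
    workflow_run_id="wf-run-001",
    user_content="Analyze this image",
    assistant_content="This is a landscape image...",
)
memory.flush()
```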