From a9bae7aafdb560e369173cdc0e605df893ff786e Mon Sep 17 00:00:00 2001
From: Novice
Date: Sun, 27 Apr 2025 13:30:53 +0800
Subject: [PATCH] feat: add independent memory

---
 api/core/memory/__init__.py                          |   0
 api/core/memory/base_memory.py                       |  62 +++-
 api/core/memory/model_context_memory.py              | 281 +++++++++---------
 api/core/memory/token_buffer_memory.py               |  45 +--
 api/models/workflow.py                               |  11 +
 .../nodes/_base/components/memory-config.tsx         |   1 +
 6 files changed, 206 insertions(+), 194 deletions(-)
 create mode 100644 api/core/memory/__init__.py

diff --git a/api/core/memory/__init__.py b/api/core/memory/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/memory/base_memory.py b/api/core/memory/base_memory.py
index 259d6d6a59..f17191f44c 100644
--- a/api/core/memory/base_memory.py
+++ b/api/core/memory/base_memory.py
@@ -1,18 +1,64 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from typing import Optional

-from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    TextPromptMessageContent,
+)


 class BaseMemory(ABC):
     @abstractmethod
-    def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
         """
-        Get the history prompt messages
+        Get history prompt messages.
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        :return:
         """

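+    # get_history_prompt_text below is a shared template built on top of
+    # get_history_prompt_messages, so subclasses only implement message
+    # retrieval and inherit the plain-text rendering.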
-    @abstractmethod
-    def get_history_prompt_text(self) -> str:
+    def get_history_prompt_text(
+        self,
+        human_prefix: str = "Human",
+        ai_prefix: str = "Assistant",
+        max_token_limit: int = 2000,
+        message_limit: Optional[int] = None,
+    ) -> str:
         """
-        Get the history prompt text
+        Get history prompt text.
+        :param human_prefix: human prefix
+        :param ai_prefix: ai prefix
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        :return:
         """
+        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+
+        string_messages = []
+        for m in prompt_messages:
+            if m.role == PromptMessageRole.USER:
+                role = human_prefix
+            elif m.role == PromptMessageRole.ASSISTANT:
+                role = ai_prefix
+            else:
+                continue
+
+            if isinstance(m.content, list):
+                inner_msg = ""
+                for content in m.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        inner_msg += f"{content.data}\n"
+                    elif isinstance(content, ImagePromptMessageContent):
+                        inner_msg += "[image]\n"
+
+                string_messages.append(f"{role}: {inner_msg.strip()}")
+            else:
+                message = f"{role}: {m.content}"
+                string_messages.append(message)
+
+        return "\n".join(string_messages)
diff --git a/api/core/memory/model_context_memory.py b/api/core/memory/model_context_memory.py
index 3dc58e4baf..d4ec9d15c8 100644
--- a/api/core/memory/model_context_memory.py
+++ b/api/core/memory/model_context_memory.py
@@ -1,26 +1,28 @@
 import json
 from collections.abc import Sequence
-from typing import Optional
+from typing import Optional, cast

+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.file import file_manager
+from core.memory.base_memory import BaseMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities import (
-    ImagePromptMessageContent,
-    PromptMessageRole,
-    TextPromptMessageContent,
-)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentUnionTypes,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
-from models.model import Conversation, Message
-from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus
+from factories import file_factory
+from models.model import AppMode, Conversation, Message, MessageFile
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun


-class ModelContextMemory:
+class ModelContextMemory(BaseMemory):
     def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
         self.conversation = conversation
         self.node_id = node_id
@@ -34,8 +36,104 @@ class ModelContextMemory:
         :param max_token_limit: max token limit
         :param message_limit: message limit
         """
+        thread_messages = list(reversed(self._fetch_thread_messages(message_limit)))
+        if not thread_messages:
+            return []
+        # Get all required workflow_run_ids
+        workflow_run_ids = [msg.workflow_run_id for msg in thread_messages]

-        # fetch limited messages, and return reversed
+        # Batch query all related WorkflowNodeExecution records
+        node_executions = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.workflow_run_id.in_(workflow_run_ids),
+                WorkflowNodeExecution.node_id == self.node_id,
+                WorkflowNodeExecution.status.in_(
+                    [WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
+                ),
+            )
+            .all()
+        )
+
+        # Create mapping from workflow_run_id to node_execution
+        node_execution_map = {ne.workflow_run_id: ne for ne in node_executions}
+
+        # Get the last node_execution
+        last_node_execution = node_execution_map.get(thread_messages[-1].workflow_run_id)
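+        # The prompts this node recorded in process_data on its previous run seed
+        # the history; file contents from the stored messages are merged back below.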
+        prompt_messages = self._get_prompt_messages_in_process_data(last_node_execution)
+
+        # Batch query all message-related files
+        message_ids = [msg.id for msg in thread_messages]
+        all_files = db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_ids)).all()
+
+        # Create mapping from message_id to files
+        files_map = {}
+        for file in all_files:
+            if file.message_id not in files_map:
+                files_map[file.message_id] = []
+            files_map[file.message_id].append(file)
+
+        for message in thread_messages:
+            files = files_map.get(message.id, [])
+            node_execution = node_execution_map.get(message.workflow_run_id)
+            if node_execution and files:
+                file_objs, detail = self._handle_file(message, files)
+                if file_objs:
+                    outputs = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
+                    if not outputs:
+                        continue
+                    # Locate the assistant message produced by this run; the
+                    # preceding user message is the one that carried the files.
+                    prompt_contents = [prompt.content for prompt in prompt_messages]
+                    if outputs not in prompt_contents:
+                        continue
+                    outputs_index = prompt_contents.index(outputs)
+                    if outputs_index == 0:
+                        continue
+                    prompt_index = outputs_index - 1
+                    prompt_message_contents: list[PromptMessageContentUnionTypes] = []
+                    content = cast(str, prompt_messages[prompt_index].content)
+                    prompt_message_contents.append(TextPromptMessageContent(data=content))
+                    for file in file_objs:
+                        prompt_message = file_manager.to_prompt_message_content(
+                            file,
+                            image_detail_config=detail,
+                        )
+                        prompt_message_contents.append(prompt_message)
+                    prompt_messages[prompt_index].content = prompt_message_contents
+
+        # Prune the history if it exceeds the max token limit, dropping oldest first
+        if prompt_messages:
+            curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
+                prompt_messages.pop(0)
+                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+
+        return prompt_messages
+
+    def _get_prompt_messages_in_process_data(
+        self,
+        node_execution: Optional[WorkflowNodeExecution],
+    ) -> list[PromptMessage]:
+        """
+        Get prompt messages in process data.
+        :param node_execution: node execution
+        :return: prompt messages
+        """
+        prompt_messages: list[PromptMessage] = []
+        if not node_execution or not node_execution.process_data:
+            return []
+
+        try:
+            process_data = json.loads(node_execution.process_data)
+            if process_data.get("memory_type", "") != LLMMemoryType.INDEPENDENT:
+                return []
+            prompts = process_data.get("prompts", [])
+            for prompt in prompts:
+                prompt_content = prompt.get("text", "")
+                if prompt.get("role", "") == "user":
+                    prompt_messages.append(UserPromptMessage(content=prompt_content))
+                elif prompt.get("role", "") == "assistant":
+                    prompt_messages.append(AssistantPromptMessage(content=prompt_content))
+            output = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
+            prompt_messages.append(AssistantPromptMessage(content=output))
+        except json.JSONDecodeError:
+            return []
+        return prompt_messages
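+    # Illustrative process_data payload consumed above (an assumption inferred
+    # from the keys this patch reads back, not a schema defined in this diff):
+    #   {
+    #       "memory_type": "independent",
+    #       "prompts": [
+    #           {"role": "user", "text": "What is Dify?"},
+    #           {"role": "assistant", "text": "An open-source LLM app platform."},
+    #       ],
+    #   }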
+
+    def _fetch_thread_messages(self, message_limit: int | None = None) -> list[Message]:
+        """
+        Fetch thread messages.
+        :param message_limit: message limit
+        :return: thread messages
+        """
         query = (
             db.session.query(
                 Message.id,
@@ -59,147 +157,44 @@ class ModelContextMemory:

         messages = query.limit(message_limit).all()

-        # instead of all messages from the conversation, we only need to extract messages
-        # that belong to the thread of last message
+        # extract only the messages that belong to the thread of the last message
         thread_messages = extract_thread_messages(messages)

         # for newly created message, its answer is temporarily empty, we don't need to add it to memory
         if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
             thread_messages.pop(0)

-        if len(thread_messages) == 0:
+        if not thread_messages:
             return []

-        last_thread_message = list(reversed(thread_messages))[0]
-        last_node_execution = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(
-                WorkflowNodeExecution.workflow_run_id == last_thread_message.workflow_run_id,
-                WorkflowNodeExecution.node_id == self.node_id,
-                WorkflowNodeExecution.status.in_(
-                    [WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
-                ),
-            )
-            .order_by(WorkflowNodeExecution.created_at.desc())
-            .first()
-        )
-        prompt_messages: list[PromptMessage] = []
+        return thread_messages

-        # files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
-        # if files:
-        #     file_extra_config = None
-        #     if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-        #         file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
-        #     else:
-        #         if message.workflow_run_id:
-        #             workflow_run = (
-        #                 db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
-        #             )
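+    # NOTE: the logic below mirrors TokenBufferMemory's file handling (previously
+    # kept here as commented-out code) so image prompts are rebuilt consistently
+    # across both memory implementations.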
+    def _handle_file(self, message: Message, files: list[MessageFile]):
+        """
+        Handle file for memory.
+        :param message: message
+        :param files: files
+        :return: file objects and detail
+        """
+        file_extra_config = None
+        if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+        else:
+            if message.workflow_run_id:
+                workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()

-        #             if workflow_run and workflow_run.workflow:
-        #                 file_extra_config = FileUploadConfigManager.convert(
-        #                     workflow_run.workflow.features_dict, is_vision=False
-        #                 )
-
-        #     detail = ImagePromptMessageContent.DETAIL.LOW
-        #     if file_extra_config and app_record:
-        #         file_objs = file_factory.build_from_message_files(
-        #             message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
-        #         )
-        #         if file_extra_config.image_config and file_extra_config.image_config.detail:
-        #             detail = file_extra_config.image_config.detail
-        #     else:
-        #         file_objs = []
-
-        #     if not file_objs:
-        #         prompt_messages.append(UserPromptMessage(content=message.query))
-        #     else:
-        #         prompt_message_contents: list[PromptMessageContentUnionTypes] = []
-        #         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
-        #         for file in file_objs:
-        #             prompt_message = file_manager.to_prompt_message_content(
-        #                 file,
-        #                 image_detail_config=detail,
-        #             )
-        #             prompt_message_contents.append(prompt_message)
-
-        #         prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-
-        #     else:
-        #         prompt_messages.append(UserPromptMessage(content=message.query))
-        if last_node_execution and last_node_execution.process_data:
-            try:
-                process_data = json.loads(last_node_execution.process_data)
-                if process_data.get("memory_type", "") == LLMMemoryType.INDEPENDENT:
-                    for prompt in process_data.get("prompts", []):
-                        if prompt.get("role") == "user":
-                            prompt_messages.append(
-                                UserPromptMessage(
-                                    content=prompt.get("content"),
-                                )
-                            )
-                        elif prompt.get("role") == "assistant":
-                            prompt_messages.append(
-                                AssistantPromptMessage(
-                                    content=prompt.get("content"),
-                                )
-                            )
-                    output = (
-                        json.loads(last_node_execution.outputs).get("text", "") if last_node_execution.outputs else ""
+                if workflow_run and workflow_run.workflow:
+                    file_extra_config = FileUploadConfigManager.convert(
+                        workflow_run.workflow.features_dict, is_vision=False
                     )
-                    prompt_messages.append(AssistantPromptMessage(content=output))
-            except json.JSONDecodeError:
-                pass
-        if not prompt_messages:
-            return []
+        detail = ImagePromptMessageContent.DETAIL.LOW
+        app_record = self.conversation.app
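+        # detail stays LOW unless the stored upload config explicitly raises it below.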
-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
-        if curr_message_tokens > max_token_limit:
-            pruned_memory = []
-            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
-                pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
-        return prompt_messages
-
-    def get_history_prompt_text(
-        self,
-        human_prefix: str = "Human",
-        ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
-        message_limit: Optional[int] = None,
-    ) -> str:
-        """
-        Get history prompt text.
-        :param human_prefix: human prefix
-        :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
-        :param message_limit: message limit
-        :return:
-        """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
-
-        string_messages = []
-        for m in prompt_messages:
-            if m.role == PromptMessageRole.USER:
-                role = human_prefix
-            elif m.role == PromptMessageRole.ASSISTANT:
-                role = ai_prefix
-            else:
-                continue
-
-            if isinstance(m.content, list):
-                inner_msg = ""
-                for content in m.content:
-                    if isinstance(content, TextPromptMessageContent):
-                        inner_msg += f"{content.data}\n"
-                    elif isinstance(content, ImagePromptMessageContent):
-                        inner_msg += "[image]\n"
-
-                string_messages.append(f"{role}: {inner_msg.strip()}")
-            else:
-                message = f"{role}: {m.content}"
-                string_messages.append(message)
-
-        return "\n".join(string_messages)
+        if file_extra_config and app_record:
+            file_objs = file_factory.build_from_message_files(
+                message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
+            )
+            if file_extra_config.image_config and file_extra_config.image_config.detail:
+                detail = file_extra_config.image_config.detail
+        else:
+            file_objs = []
+        return file_objs, detail
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 2254b3d4d5..d2d80dc867 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -3,12 +3,12 @@ from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
+from core.memory.base_memory import BaseMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
@@ -20,7 +20,7 @@ from models.model import AppMode, Conversation, Message, MessageFile
 from models.workflow import WorkflowRun


-class TokenBufferMemory:
+class TokenBufferMemory(BaseMemory):
     def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
         self.conversation = conversation
         self.model_instance = model_instance
@@ -129,44 +129,3 @@ class TokenBufferMemory:
             curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

         return prompt_messages
-
-    def get_history_prompt_text(
-        self,
-        human_prefix: str = "Human",
-        ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
-        message_limit: Optional[int] = None,
-    ) -> str:
-        """
-        Get history prompt text.
-        :param human_prefix: human prefix
-        :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
-        :param message_limit: message limit
-        :return:
-        """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
-
-        string_messages = []
-        for m in prompt_messages:
-            if m.role == PromptMessageRole.USER:
-                role = human_prefix
-            elif m.role == PromptMessageRole.ASSISTANT:
-                role = ai_prefix
-            else:
-                continue
-
-            if isinstance(m.content, list):
-                inner_msg = ""
-                for content in m.content:
-                    if isinstance(content, TextPromptMessageContent):
-                        inner_msg += f"{content.data}\n"
-                    elif isinstance(content, ImagePromptMessageContent):
-                        inner_msg += "[image]\n"
-
-                string_messages.append(f"{role}: {inner_msg.strip()}")
-            else:
-                message = f"{role}: {m.content}"
-                string_messages.append(message)
-
-        return "\n".join(string_messages)
diff --git a/api/models/workflow.py b/api/models/workflow.py
index da60617de5..9374afd8b0 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -606,6 +606,17 @@ class WorkflowNodeExecution(Base):
             "triggered_from",
             "node_execution_id",
         ),
+        db.Index(
+            "workflow_node_execution_run_node_status_idx",
+            "workflow_run_id",
+            "node_id",
+            "status",
+        ),
+        db.Index(
+            "workflow_node_execution_run_status_idx",
+            "workflow_run_id",
+            "status",
+        ),
     )

     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx
index 23edcb66ba..5f0ef1de76 100644
--- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx
+++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx
@@ -195,6 +195,7 @@ const MemoryConfig: FC<Props> = ({
           })
           onChange(newPayload)
         }}
+        defaultValue={payload.type}
       />
       {canSetRoleName && (
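
Usage sketch (illustrative only, not part of the diff): how the two memory classes
touched by this patch would be driven, assuming an existing conversation, a
model_instance, and the target LLM node's id ("llm_node_1" is a placeholder):

    from core.memory.model_context_memory import ModelContextMemory
    from core.memory.token_buffer_memory import TokenBufferMemory

    # Node-scoped "independent" memory: history is rebuilt from this node's own
    # prior executions within the conversation thread.
    memory = ModelContextMemory(conversation=conversation, node_id="llm_node_1", model_instance=model_instance)
    messages = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)

    # Conversation-wide memory keeps its existing behavior.
    shared = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

    # Both inherit BaseMemory, so either can render the history as plain text.
    text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant")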