diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
index 159500a609..d02ca1ecbe 100644
--- a/api/core/app/workflow/node_factory.py
+++ b/api/core/app/workflow/node_factory.py
@@ -12,6 +12,7 @@ from core.helper.ssrf_proxy import ssrf_proxy
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.graph_config import NodeConfigDict
@@ -26,9 +27,11 @@ from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm import llm_utils
 from core.workflow.nodes.llm.entities import ModelConfig
 from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.protocols import PromptMessageMemory
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -177,6 +180,7 @@ class DifyNodeFactory(NodeFactory):
 
         if node_type == NodeType.LLM:
             model_instance = self._build_model_instance_for_llm_node(node_data)
+            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
             return LLMNode(
                 id=node_id,
                 config=node_config,
@@ -185,6 +189,7 @@ class DifyNodeFactory(NodeFactory):
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
+                memory=memory,
             )
 
         if node_type == NodeType.DATASOURCE:
@@ -278,3 +283,21 @@ class DifyNodeFactory(NodeFactory):
         model_instance.stop = tuple(stop)
         model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
         return model_instance
+
+    def _build_memory_for_llm_node(
+        self,
+        *,
+        node_data: Mapping[str, Any],
+        model_instance: ModelInstance,
+    ) -> PromptMessageMemory | None:
+        raw_memory_config = node_data.get("memory")
+        if raw_memory_config is None:
+            return None
+
+        node_memory = MemoryConfig.model_validate(raw_memory_config)
+        return llm_utils.fetch_memory(
+            variable_pool=self.graph_runtime_state.variable_pool,
+            app_id=self.graph_init_params.app_id,
+            node_data_memory=node_memory,
+            model_instance=model_instance,
+        )
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 33dd88156f..057a144e89 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -14,7 +14,6 @@ from sqlalchemy import select
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -63,7 +62,7 @@ from core.workflow.node_events import (
 from core.workflow.nodes.base.entities import VariableSelector
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
-from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
 from core.workflow.runtime import VariablePool
 from core.workflow.variables import (
     ArrayFileSegment,
@@ -115,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
     _credentials_provider: CredentialsProvider
     _model_factory: ModelFactory
     _model_instance: ModelInstance
+    _memory: PromptMessageMemory | None
 
     def __init__(
         self,
@@ -126,6 +126,7 @@ class LLMNode(Node[LLMNodeData]):
         credentials_provider: CredentialsProvider,
         model_factory: ModelFactory,
         model_instance: ModelInstance,
+        memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -140,6 +141,7 @@ class LLMNode(Node[LLMNodeData]):
         self._credentials_provider = credentials_provider
         self._model_factory = model_factory
         self._model_instance = model_instance
+        self._memory = memory
 
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
@@ -208,13 +210,7 @@ class LLMNode(Node[LLMNodeData]):
         model_provider = model_instance.provider
         model_stop = model_instance.stop
 
-        # fetch memory
-        memory = llm_utils.fetch_memory(
-            variable_pool=variable_pool,
-            app_id=self.app_id,
-            node_data_memory=self.node_data.memory,
-            model_instance=model_instance,
-        )
+        memory = self._memory
 
         query: str | None = None
         if self.node_data.memory:
@@ -762,7 +758,7 @@ class LLMNode(Node[LLMNodeData]):
         sys_query: str | None = None,
         sys_files: Sequence[File],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: PromptMessageMemory | None = None,
         model_instance: ModelInstance,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         stop: Sequence[str] | None = None,
@@ -1307,7 +1303,7 @@ def _calculate_rest_token(
 
 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: PromptMessageMemory | None,
     memory_config: MemoryConfig | None,
     model_instance: ModelInstance,
 ) -> Sequence[PromptMessage]:
@@ -1327,7 +1323,7 @@ def _handle_memory_chat_mode(
 
 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: PromptMessageMemory | None,
     memory_config: MemoryConfig | None,
     model_instance: ModelInstance,
 ) -> str:
@@ -1340,15 +1336,48 @@ def _handle_memory_completion_mode(
     )
     if not memory_config.role_prefix:
         raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-    memory_text = memory.get_history_prompt_text(
+    memory_messages = memory.get_history_prompt_messages(
         max_token_limit=rest_tokens,
         message_limit=memory_config.window.size if memory_config.window.enabled else None,
+    )
+    memory_text = _convert_history_messages_to_text(
+        history_messages=memory_messages,
         human_prefix=memory_config.role_prefix.user,
         ai_prefix=memory_config.role_prefix.assistant,
     )
     return memory_text
 
+
+def _convert_history_messages_to_text(
+    *,
+    history_messages: Sequence[PromptMessage],
+    human_prefix: str,
+    ai_prefix: str,
+) -> str:
+    string_messages: list[str] = []
+    for message in history_messages:
+        if message.role == PromptMessageRole.USER:
+            role = human_prefix
+        elif message.role == PromptMessageRole.ASSISTANT:
+            role = ai_prefix
+        else:
+            continue
+
+        if isinstance(message.content, list):
+            content_parts = []
+            for content in message.content:
+                if isinstance(content, TextPromptMessageContent):
+                    content_parts.append(content.data)
+                elif isinstance(content, ImagePromptMessageContent):
+                    content_parts.append("[image]")
+
+            inner_msg = "\n".join(content_parts)
+            string_messages.append(f"{role}: {inner_msg}")
+        else:
+            string_messages.append(f"{role}: {message.content}")
+    return "\n".join(string_messages)
+
 
 def _handle_completion_template(
     *,
     template: LLMNodeCompletionModelPromptTemplate,
diff --git a/api/core/workflow/nodes/llm/protocols.py b/api/core/workflow/nodes/llm/protocols.py
index 8e0365299d..5bca04165a 100644
--- a/api/core/workflow/nodes/llm/protocols.py
+++ b/api/core/workflow/nodes/llm/protocols.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import Any, Protocol
 
 from core.model_manager import ModelInstance
+from core.model_runtime.entities import PromptMessage
 
 
 class CredentialsProvider(Protocol):
@@ -19,3 +21,13 @@ class ModelFactory(Protocol):
     def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
         """Create a model instance that is ready for schema lookup and invocation."""
         ...
+
+
+class PromptMessageMemory(Protocol):
+    """Port for loading memory as prompt messages for LLM nodes."""
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: int | None = None
+    ) -> Sequence[PromptMessage]:
+        """Return historical prompt messages constrained by token/message limits."""
+        ...
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index e505aed323..3c365a6a0e 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -12,6 +12,7 @@ from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageRole,
@@ -20,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.workflow.entities import GraphInitParams
 from core.workflow.file import File, FileTransferMethod, FileType
 from core.workflow.nodes.llm import llm_utils
@@ -32,7 +34,7 @@ from core.workflow.nodes.llm.entities import (
     VisionConfigOptions,
 )
 from core.workflow.nodes.llm.file_saver import LLMFileSaver
-from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode
 from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
@@ -587,6 +589,41 @@ def test_handle_list_messages_basic(llm_node):
     assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
 
+
+def test_handle_memory_completion_mode_uses_prompt_message_interface():
+    memory = mock.MagicMock(spec=MockTokenBufferMemory)
+    memory.get_history_prompt_messages.return_value = [
+        UserPromptMessage(
+            content=[
+                TextPromptMessageContent(data="first question"),
+                ImagePromptMessageContent(
+                    format="png",
+                    url="https://example.com/image.png",
+                    mime_type="image/png",
+                ),
+            ]
+        ),
+        AssistantPromptMessage(content="first answer"),
+    ]
+
+    model_instance = mock.MagicMock(spec=ModelInstance)
+
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=3),
+    )
+
+    with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
+        memory_text = _handle_memory_completion_mode(
+            memory=memory,
+            memory_config=memory_config,
+            model_instance=model_instance,
+        )
+
+    assert memory_text == "Human: first question\n[image]\nAssistant: first answer"
+    mock_rest_token.assert_called_once_with(prompt_messages=[], model_instance=model_instance)
+    memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=3)
+
 
 @pytest.fixture
 def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
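
Because PromptMessageMemory is a structural Protocol, any object with a matching get_history_prompt_messages signature can back an LLMNode; TokenBufferMemory already satisfies it, which is why llm_utils.fetch_memory still plugs in unchanged from the factory. A minimal sketch of the seam this opens up (FixedHistoryMemory and its trimming logic are illustrative, not part of this change; imports assume the same message_entities module the test file uses):

from collections.abc import Sequence

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    UserPromptMessage,
)


class FixedHistoryMemory:
    """Illustrative memory backend serving a pre-built message list.

    No inheritance needed: structurally matching get_history_prompt_messages()
    is enough to satisfy PromptMessageMemory.
    """

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        # Toy trimming: honor message_limit only. A production adapter would
        # also trim to max_token_limit the way TokenBufferMemory does.
        if message_limit is not None:
            return self._messages[-message_limit:]
        return self._messages


# Injected the same way DifyNodeFactory now wires it:
#   memory = FixedHistoryMemory([
#       UserPromptMessage(content="hi"),
#       AssistantPromptMessage(content="hello"),
#   ])
#   node = LLMNode(..., memory=memory)  # memory=None keeps the node history-free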