refactor: inject memory interface into LLMNode (#32754)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
-LAN- committed via GitHub on 2026-03-01 04:05:18 +08:00
commit c034eb036c (parent 1f0fca89a8)
4 changed files with 115 additions and 14 deletions

View File

@@ -12,6 +12,7 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
@@ -26,9 +27,11 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -177,6 +180,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
@@ -185,6 +189,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.DATASOURCE:
@@ -278,3 +283,21 @@ class DifyNodeFactory(NodeFactory):
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)
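
For context, _build_memory_for_llm_node only builds a memory when the node's configuration carries a "memory" key. A hypothetical node_data fragment that would trigger it (a sketch; the window/role_prefix field names follow MemoryConfig as exercised in the tests below, not a confirmed schema):

    node_data = {
        "memory": {
            # message-count window applied when fetching history
            "window": {"enabled": True, "size": 10},
            # prefixes used only by completion-mode prompt assembly
            "role_prefix": {"user": "Human", "assistant": "Assistant"},
        },
    }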

View File

@@ -14,7 +14,6 @@ from sqlalchemy import select
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
ImagePromptMessageContent,
@@ -63,7 +62,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
@@ -115,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -126,6 +126,7 @@ class LLMNode(Node[LLMNodeData]):
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -140,6 +141,7 @@ class LLMNode(Node[LLMNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -208,13 +210,7 @@ class LLMNode(Node[LLMNodeData]):
model_provider = model_instance.provider
model_stop = model_instance.stop
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)
memory = self._memory
query: str | None = None
if self.node_data.memory:
@@ -762,7 +758,7 @@ class LLMNode(Node[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
@@ -1307,7 +1303,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
@@ -1327,7 +1323,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
@@ -1340,15 +1336,48 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
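
To illustrate the new text conversion, a minimal sketch of calling _convert_history_messages_to_text directly (message classes assumed importable from core.model_runtime.entities, as elsewhere in this file; roles other than user/assistant are skipped by the helper):

    from core.model_runtime.entities import (
        AssistantPromptMessage,
        SystemPromptMessage,
        UserPromptMessage,
    )

    history = [
        SystemPromptMessage(content="you are helpful"),  # skipped: not USER/ASSISTANT
        UserPromptMessage(content="first question"),
        AssistantPromptMessage(content="first answer"),
    ]
    text = _convert_history_messages_to_text(
        history_messages=history,
        human_prefix="Human",
        ai_prefix="Assistant",
    )
    assert text == "Human: first question\nAssistant: first answer"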

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@@ -19,3 +21,13 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...
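
Since PromptMessageMemory is a Protocol, any object exposing a matching get_history_prompt_messages method satisfies it structurally; TokenBufferMemory keeps working unchanged, and tests can inject a lightweight double. A minimal hypothetical stand-in (ignores token limiting for brevity; Sequence and PromptMessage as imported above):

    class StaticMemory:
        """Fixed message history; a sketch of a conforming test double."""

        def __init__(self, messages: Sequence[PromptMessage]) -> None:
            self._messages = list(messages)

        def get_history_prompt_messages(
            self, max_token_limit: int = 2000, message_limit: int | None = None
        ) -> Sequence[PromptMessage]:
            # Apply only the message-count window; a real implementation
            # would also trim to max_token_limit.
            if message_limit is not None:
                return self._messages[-message_limit:]
            return list(self._messages)

Such an object can be handed to LLMNode's memory parameter directly, with no patching of llm_utils.fetch_memory required.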

View File

@@ -12,6 +12,7 @@ from core.entities.provider_entities import CustomConfiguration, SystemConfigura
from core.model_manager import ModelInstance
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
@@ -20,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities import GraphInitParams
from core.workflow.file import File, FileTransferMethod, FileType
from core.workflow.nodes.llm import llm_utils
@@ -32,7 +34,7 @@ from core.workflow.nodes.llm.entities import (
VisionConfigOptions,
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -587,6 +589,41 @@ def test_handle_list_messages_basic(llm_node):
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
def test_handle_memory_completion_mode_uses_prompt_message_interface():
memory = mock.MagicMock(spec=MockTokenBufferMemory)
memory.get_history_prompt_messages.return_value = [
UserPromptMessage(
content=[
TextPromptMessageContent(data="first question"),
ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/png",
),
]
),
AssistantPromptMessage(content="first answer"),
]
model_instance = mock.MagicMock(spec=ModelInstance)
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=True, size=3),
)
with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
assert memory_text == "Human: first question\n[image]\nAssistant: first answer"
mock_rest_token.assert_called_once_with(prompt_messages=[], model_instance=model_instance)
memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=3)
@pytest.fixture
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)