refactor: inject memory interface into LLMNode (#32754)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 1f0fca89a8
commit c034eb036c
@@ -12,6 +12,7 @@ from core.helper.ssrf_proxy import ssrf_proxy
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.graph_config import NodeConfigDict
@@ -26,9 +27,11 @@ from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm import llm_utils
 from core.workflow.nodes.llm.entities import ModelConfig
 from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.protocols import PromptMessageMemory
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -177,6 +180,7 @@ class DifyNodeFactory(NodeFactory):

         if node_type == NodeType.LLM:
             model_instance = self._build_model_instance_for_llm_node(node_data)
+            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
             return LLMNode(
                 id=node_id,
                 config=node_config,
@@ -185,6 +189,7 @@ class DifyNodeFactory(NodeFactory):
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
+                memory=memory,
             )

         if node_type == NodeType.DATASOURCE:
@@ -278,3 +283,21 @@ class DifyNodeFactory(NodeFactory):
         model_instance.stop = tuple(stop)
         model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
         return model_instance
+
+    def _build_memory_for_llm_node(
+        self,
+        *,
+        node_data: Mapping[str, Any],
+        model_instance: ModelInstance,
+    ) -> PromptMessageMemory | None:
+        raw_memory_config = node_data.get("memory")
+        if raw_memory_config is None:
+            return None
+
+        node_memory = MemoryConfig.model_validate(raw_memory_config)
+        return llm_utils.fetch_memory(
+            variable_pool=self.graph_runtime_state.variable_pool,
+            app_id=self.graph_init_params.app_id,
+            node_data_memory=node_memory,
+            model_instance=model_instance,
+        )
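The factory hunks above move memory construction out of the node: the node's raw "memory" mapping is validated into a MemoryConfig and resolved through llm_utils.fetch_memory before LLMNode is built, and the result is handed to the node as a constructor argument. A minimal sketch of the kind of raw config _build_memory_for_llm_node would accept, assuming only the MemoryConfig fields exercised by the test further down (role_prefix and window) and standard pydantic v2 validation; the real schema may accept more keys:

from core.prompt.entities.advanced_prompt_entities import MemoryConfig

# Hypothetical raw node_data["memory"] mapping, shaped after the test below.
raw_memory_config = {
    "role_prefix": {"user": "Human", "assistant": "Assistant"},
    "window": {"enabled": True, "size": 3},
}

node_memory = MemoryConfig.model_validate(raw_memory_config)
assert node_memory.window.size == 3

If node_data carries no "memory" key, the factory returns None and the node simply runs without conversation history.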
@@ -14,7 +14,6 @@ from sqlalchemy import select
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -63,7 +62,7 @@ from core.workflow.node_events import (
 from core.workflow.nodes.base.entities import VariableSelector
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
-from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
 from core.workflow.runtime import VariablePool
 from core.workflow.variables import (
     ArrayFileSegment,
@@ -115,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
     _credentials_provider: CredentialsProvider
     _model_factory: ModelFactory
     _model_instance: ModelInstance
+    _memory: PromptMessageMemory | None

     def __init__(
         self,
@@ -126,6 +126,7 @@ class LLMNode(Node[LLMNodeData]):
         credentials_provider: CredentialsProvider,
         model_factory: ModelFactory,
         model_instance: ModelInstance,
+        memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -140,6 +141,7 @@ class LLMNode(Node[LLMNodeData]):
         self._credentials_provider = credentials_provider
         self._model_factory = model_factory
         self._model_instance = model_instance
+        self._memory = memory

         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
@@ -208,13 +210,7 @@
         model_provider = model_instance.provider
         model_stop = model_instance.stop

-        # fetch memory
-        memory = llm_utils.fetch_memory(
-            variable_pool=variable_pool,
-            app_id=self.app_id,
-            node_data_memory=self.node_data.memory,
-            model_instance=model_instance,
-        )
+        memory = self._memory

         query: str | None = None
         if self.node_data.memory:
@@ -762,7 +758,7 @@
         sys_query: str | None = None,
         sys_files: Sequence[File],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: PromptMessageMemory | None = None,
         model_instance: ModelInstance,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         stop: Sequence[str] | None = None,
@@ -1307,7 +1303,7 @@ def _calculate_rest_token(

 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: PromptMessageMemory | None,
     memory_config: MemoryConfig | None,
     model_instance: ModelInstance,
 ) -> Sequence[PromptMessage]:
@@ -1327,7 +1323,7 @@ def _handle_memory_chat_mode(

 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: PromptMessageMemory | None,
     memory_config: MemoryConfig | None,
     model_instance: ModelInstance,
 ) -> str:
@@ -1340,15 +1336,48 @@ def _handle_memory_completion_mode(
     )
     if not memory_config.role_prefix:
         raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-    memory_text = memory.get_history_prompt_text(
+    memory_messages = memory.get_history_prompt_messages(
         max_token_limit=rest_tokens,
         message_limit=memory_config.window.size if memory_config.window.enabled else None,
+    )
+    memory_text = _convert_history_messages_to_text(
+        history_messages=memory_messages,
         human_prefix=memory_config.role_prefix.user,
         ai_prefix=memory_config.role_prefix.assistant,
     )
     return memory_text


+def _convert_history_messages_to_text(
+    *,
+    history_messages: Sequence[PromptMessage],
+    human_prefix: str,
+    ai_prefix: str,
+) -> str:
+    string_messages: list[str] = []
+    for message in history_messages:
+        if message.role == PromptMessageRole.USER:
+            role = human_prefix
+        elif message.role == PromptMessageRole.ASSISTANT:
+            role = ai_prefix
+        else:
+            continue
+
+        if isinstance(message.content, list):
+            content_parts = []
+            for content in message.content:
+                if isinstance(content, TextPromptMessageContent):
+                    content_parts.append(content.data)
+                elif isinstance(content, ImagePromptMessageContent):
+                    content_parts.append("[image]")
+
+            inner_msg = "\n".join(content_parts)
+            string_messages.append(f"{role}: {inner_msg}")
+        else:
+            string_messages.append(f"{role}: {message.content}")
+    return "\n".join(string_messages)
+
+
 def _handle_completion_template(
     *,
     template: LLMNodeCompletionModelPromptTemplate,
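In node.py, _handle_memory_completion_mode no longer relies on TokenBufferMemory.get_history_prompt_text; it pulls messages through the protocol's get_history_prompt_messages and flattens them locally with the new _convert_history_messages_to_text helper. A small usage sketch of that helper, assuming the message classes come from core.model_runtime.entities.message_entities as in the test imports further down (the helper is module-private, so this is illustrative rather than a public API):

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.workflow.nodes.llm.node import _convert_history_messages_to_text

history = [
    UserPromptMessage(content=[TextPromptMessageContent(data="What is Dify?")]),
    AssistantPromptMessage(content="An LLM application platform."),
]

text = _convert_history_messages_to_text(
    history_messages=history,
    human_prefix="Human",
    ai_prefix="Assistant",
)
# text == "Human: What is Dify?\nAssistant: An LLM application platform."

Roles other than user and assistant are skipped, and image parts collapse to the literal "[image]" marker, which is exactly what the new unit test asserts.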
@@ -1,8 +1,10 @@
 from __future__ import annotations

+from collections.abc import Sequence
 from typing import Any, Protocol

 from core.model_manager import ModelInstance
+from core.model_runtime.entities import PromptMessage


 class CredentialsProvider(Protocol):
@@ -19,3 +21,13 @@ class ModelFactory(Protocol):
     def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
         """Create a model instance that is ready for schema lookup and invocation."""
         ...
+
+
+class PromptMessageMemory(Protocol):
+    """Port for loading memory as prompt messages for LLM nodes."""
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: int | None = None
+    ) -> Sequence[PromptMessage]:
+        """Return historical prompt messages constrained by token/message limits."""
+        ...
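PromptMessageMemory is a typing.Protocol, so satisfaction is structural: anything exposing get_history_prompt_messages with this shape can be injected, and the existing memory returned by llm_utils.fetch_memory presumably already conforms. A hedged sketch of a trivial stand-in that could be handed to LLMNode as the memory argument in tests; the class name and behavior are hypothetical, not part of the codebase:

from collections.abc import Sequence

from core.model_runtime.entities import PromptMessage


class StaticPromptMessageMemory:
    """Serve a fixed message list; honors only the message limit, not the token budget."""

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        if message_limit is not None:
            return self._messages[-message_limit:]
        return self._messages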
@@ -12,6 +12,7 @@ from core.entities.provider_entities import CustomConfiguration, SystemConfigura
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageRole,
@@ -20,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.workflow.entities import GraphInitParams
 from core.workflow.file import File, FileTransferMethod, FileType
 from core.workflow.nodes.llm import llm_utils
@@ -32,7 +34,7 @@ from core.workflow.nodes.llm.entities import (
     VisionConfigOptions,
 )
 from core.workflow.nodes.llm.file_saver import LLMFileSaver
-from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode
 from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
@@ -587,6 +589,41 @@ def test_handle_list_messages_basic(llm_node):
     assert result[0].content == [TextPromptMessageContent(data="Hello, world")]


+def test_handle_memory_completion_mode_uses_prompt_message_interface():
+    memory = mock.MagicMock(spec=MockTokenBufferMemory)
+    memory.get_history_prompt_messages.return_value = [
+        UserPromptMessage(
+            content=[
+                TextPromptMessageContent(data="first question"),
+                ImagePromptMessageContent(
+                    format="png",
+                    url="https://example.com/image.png",
+                    mime_type="image/png",
+                ),
+            ]
+        ),
+        AssistantPromptMessage(content="first answer"),
+    ]
+
+    model_instance = mock.MagicMock(spec=ModelInstance)
+
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=3),
+    )
+
+    with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
+        memory_text = _handle_memory_completion_mode(
+            memory=memory,
+            memory_config=memory_config,
+            model_instance=model_instance,
+        )
+
+    assert memory_text == "Human: first question\n[image]\nAssistant: first answer"
+    mock_rest_token.assert_called_once_with(prompt_messages=[], model_instance=model_instance)
+    memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=3)
+
+
 @pytest.fixture
 def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
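The new test drives _handle_memory_completion_mode entirely through the protocol surface. Building the fake with mock.MagicMock(spec=MockTokenBufferMemory) is what keeps it honest: a spec'd mock raises AttributeError for attributes the spec class does not define, so the test would fail if the handler reached for anything outside get_history_prompt_messages. A small self-contained sketch of that behavior; FakeMemory is a hypothetical stand-in for the module's MockTokenBufferMemory, whose definition is not shown in this diff:

from unittest import mock


class FakeMemory:
    def get_history_prompt_messages(self, max_token_limit: int = 2000, message_limit: int | None = None):
        return []


memory = mock.MagicMock(spec=FakeMemory)
memory.get_history_prompt_messages.return_value = []
memory.get_history_prompt_messages(max_token_limit=2000, message_limit=3)  # allowed by the spec
try:
    memory.get_history_prompt_text  # not defined on FakeMemory
except AttributeError:
    print("spec'd mock rejects methods outside the protocol surface")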