refactor: inject workflow node memory via protocol (#32784)

This commit is contained in:
-LAN- 2026-03-02 01:55:49 +08:00 committed by GitHub
parent ef2b5d6107
commit 69b3e94630
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 130 additions and 108 deletions

View File

@ -54,7 +54,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
# TODO(QuantumGhost): use DI to avoid depending on global DB.
@ -114,7 +113,6 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@ -150,7 +148,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
@ -172,7 +169,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database

View File

@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@ -11,6 +13,7 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -18,7 +21,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@ -29,7 +32,6 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,

View File

@ -1,22 +1,21 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from models.model import Conversation
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from .exc import InvalidVariableTypeError
from .protocols import PromptMessageMemory
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
@ -42,23 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
def convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
return "\n".join(string_messages)
def fetch_memory_text(
*,
memory: PromptMessageMemory,
max_token_limit: int,
message_limit: int | None = None,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
) -> str:
history_messages = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
return convert_history_messages_to_text(
history_messages=history_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
)

View File

@ -1338,48 +1338,16 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,

View File

@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@ -24,12 +23,17 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.runtime import VariablePool
from core.workflow.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
@ -108,6 +113,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
@ -118,6 +124,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -163,13 +170,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@ -316,7 +317,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@ -404,7 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -442,7 +443,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -467,7 +468,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -480,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -712,7 +714,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -721,8 +723,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -739,7 +741,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -748,8 +750,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(

View File

@ -3,7 +3,6 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
@ -27,7 +26,7 @@ from core.workflow.nodes.llm import (
llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@ -56,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@ -67,6 +67,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@ -81,6 +82,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -103,13 +105,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@ -327,7 +323,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -340,7 +336,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
input_text = query
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
memory_str = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
def get_mocked_fetch_memory(memory_text: str):
class MemoryMock:
def get_history_prompt_text(
def get_history_prompt_messages(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
):
return memory_text
return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
return MagicMock(return_value=MemoryMock())
def init_parameter_extractor_node(config: dict):
def init_parameter_extractor_node(config: dict, memory=None):
graph_config = {
"edges": [
{
@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
memory=memory,
)
return node
@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor_with_memory(setup_model_mock):
"""
Test chat parameter extractor with memory.
"""
@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"memory": {"window": {"enabled": True, "size": 50}},
},
},
memory=get_mocked_fetch_memory("customized memory")(),
)
node._model_instance = get_mocked_fetch_model_instance(
@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()