From 69b3e94630f9bd547e766b0bc3abafcbe381746f Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Mon, 2 Mar 2026 01:55:49 +0800
Subject: [PATCH] refactor: inject workflow node memory via protocol (#32784)

Move the conversation lookup out of the shared LLM helpers and into the
app-layer node factory. QuestionClassifierNode and ParameterExtractorNode
now receive a PromptMessageMemory implementation through their constructors
instead of building a TokenBufferMemory from the database themselves, which
drops the llm_utils dependencies on extensions.ext_database and models.model.
---
 api/.importlinter                         |  4 -
 api/core/app/workflow/node_factory.py     | 42 +++++++++-
 api/core/workflow/nodes/llm/llm_utils.py  | 79 +++++++++++++------
 api/core/workflow/nodes/llm/node.py       | 36 +--------
 .../parameter_extractor_node.py           | 42 +++++-----
 .../question_classifier_node.py           | 19 ++---
 .../nodes/test_parameter_extractor.py     | 16 ++--
 7 files changed, 130 insertions(+), 108 deletions(-)

diff --git a/api/.importlinter b/api/.importlinter
index 3b1f58d886..74dec4a293 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -54,7 +54,6 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
-    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
     core.workflow.nodes.llm.node -> extensions.ext_database
     core.workflow.nodes.tool.tool_node -> extensions.ext_database
     # TODO(QuantumGhost): use DI to avoid depending on global DB.
@@ -114,7 +113,6 @@ ignore_imports =
     core.workflow.nodes.llm.llm_utils -> core.model_manager
     core.workflow.nodes.llm.protocols -> core.model_manager
     core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
-    core.workflow.nodes.llm.llm_utils -> models.model
    core.workflow.nodes.llm.node -> core.tools.signature
     core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
     core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -150,7 +148,6 @@ ignore_imports =
     core.workflow.nodes.llm.node -> core.model_manager
     core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
     core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
-    core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
     core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
     core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
     core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
@@ -172,7 +169,6 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
-    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
     core.workflow.nodes.llm.node -> extensions.ext_database
     core.workflow.nodes.tool.tool_node -> extensions.ext_database
     core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
index 41b8c9fd7b..970b0c4c3d 100644
--- a/api/core/app/workflow/node_factory.py
+++ b/api/core/app/workflow/node_factory.py
@@ -1,6 +1,8 @@
 from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, cast, final
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from typing_extensions import override
 
 from configs import dify_config
@@ -11,6 +13,7 @@ from core.helper.code_executor.code_executor import (
     CodeExecutor,
 )
 from core.helper.ssrf_proxy import ssrf_proxy
+from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -18,7 +21,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.graph_config import NodeConfigDict
-from core.workflow.enums import NodeType
+from core.workflow.enums import NodeType, SystemVariableKey
 from core.workflow.file.file_manager import file_manager
 from core.workflow.graph.graph import NodeFactory
 from core.workflow.nodes.base.node import Node
@@ -29,7 +32,6 @@ from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
-from core.workflow.nodes.llm import llm_utils
 from core.workflow.nodes.llm.entities import ModelConfig
 from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from core.workflow.nodes.llm.node import LLMNode
@@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
     CodeExecutorJinja2TemplateRenderer,
 )
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.variables.segments import StringSegment
+from extensions.ext_database import db
+from models.model import Conversation
 
 if TYPE_CHECKING:
     from core.workflow.entities import GraphInitParams
     from core.workflow.runtime import GraphRuntimeState
 
 
+def fetch_memory(
+    *,
+    conversation_id: str | None,
+    app_id: str,
+    node_data_memory: MemoryConfig | None,
+    model_instance: ModelInstance,
+) -> TokenBufferMemory | None:
+    if not node_data_memory or not conversation_id:
+        return None
+
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+        conversation = session.scalar(stmt)
+        if not conversation:
+            return None
+
+    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+
+
 class DefaultWorkflowCodeExecutor:
     def execute(
         self,
@@ -221,6 +245,7 @@
 
         if node_type == NodeType.QUESTION_CLASSIFIER:
             model_instance = self._build_model_instance_for_llm_node(node_data)
+            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
             return QuestionClassifierNode(
                 id=node_id,
                 config=node_config,
@@ -229,10 +254,12 @@
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
+                memory=memory,
             )
 
         if node_type == NodeType.PARAMETER_EXTRACTOR:
             model_instance = self._build_model_instance_for_llm_node(node_data)
+            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
             return ParameterExtractorNode(
                 id=node_id,
                 config=node_config,
@@ -241,6 +268,7 @@
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
+                memory=memory,
             )
 
         return node_class(
@@ -295,8 +323,14 @@
             return None
 
         node_memory = MemoryConfig.model_validate(raw_memory_config)
-        return llm_utils.fetch_memory(
-            variable_pool=self.graph_runtime_state.variable_pool,
+        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
+            ["sys", SystemVariableKey.CONVERSATION_ID]
+        )
+        conversation_id = (
+            conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
+        )
+        return fetch_memory(
+            conversation_id=conversation_id,
             app_id=self.graph_init_params.app_id,
             node_data_memory=node_memory,
             model_instance=model_instance,
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index b751640e1b..7e52a1a202 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -1,22 +1,21 @@
 from collections.abc import Sequence
 from typing import cast
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities import PromptMessageRole
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.enums import SystemVariableKey
 from core.workflow.file.models import File
 from core.workflow.runtime import VariablePool
-from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
-from extensions.ext_database import db
-from models.model import Conversation
+from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
 
 from .exc import InvalidVariableTypeError
+from .protocols import PromptMessageMemory
 
 
 def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
@@ -42,23 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
         raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
 
 
-def fetch_memory(
-    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
-) -> TokenBufferMemory | None:
-    if not node_data_memory:
-        return None
+def convert_history_messages_to_text(
+    *,
+    history_messages: Sequence[PromptMessage],
+    human_prefix: str,
+    ai_prefix: str,
+) -> str:
+    string_messages: list[str] = []
+    for message in history_messages:
+        if message.role == PromptMessageRole.USER:
+            role = human_prefix
+        elif message.role == PromptMessageRole.ASSISTANT:
+            role = ai_prefix
+        else:
+            continue
 
-    # get conversation id
-    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-    if not isinstance(conversation_id_variable, StringSegment):
-        return None
-    conversation_id = conversation_id_variable.value
+        if isinstance(message.content, list):
+            content_parts = []
+            for content in message.content:
+                if isinstance(content, TextPromptMessageContent):
+                    content_parts.append(content.data)
+                elif isinstance(content, ImagePromptMessageContent):
+                    content_parts.append("[image]")
 
-    with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-        conversation = session.scalar(stmt)
-        if not conversation:
-            return None
+            inner_msg = "\n".join(content_parts)
+            string_messages.append(f"{role}: {inner_msg}")
+        else:
+            string_messages.append(f"{role}: {message.content}")
 
-    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-    return memory
+    return "\n".join(string_messages)
+
+
+def fetch_memory_text(
+    *,
+    memory: PromptMessageMemory,
+    max_token_limit: int,
+    message_limit: int | None = None,
+    human_prefix: str = "Human",
+    ai_prefix: str = "Assistant",
+) -> str:
+    history_messages = memory.get_history_prompt_messages(
+        max_token_limit=max_token_limit,
+        message_limit=message_limit,
+    )
+    return convert_history_messages_to_text(
+        history_messages=history_messages,
+        human_prefix=human_prefix,
+        ai_prefix=ai_prefix,
+    )
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 4378201eee..475a904d1c 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -1338,48 +1338,16 @@ def _handle_memory_completion_mode(
     )
     if not memory_config.role_prefix:
         raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-    memory_messages = memory.get_history_prompt_messages(
+    memory_text = llm_utils.fetch_memory_text(
+        memory=memory,
         max_token_limit=rest_tokens,
         message_limit=memory_config.window.size if memory_config.window.enabled else None,
-    )
-    memory_text = _convert_history_messages_to_text(
-        history_messages=memory_messages,
         human_prefix=memory_config.role_prefix.user,
         ai_prefix=memory_config.role_prefix.assistant,
     )
     return memory_text
 
 
-def _convert_history_messages_to_text(
-    *,
-    history_messages: Sequence[PromptMessage],
-    human_prefix: str,
-    ai_prefix: str,
-) -> str:
-    string_messages: list[str] = []
-    for message in history_messages:
-        if message.role == PromptMessageRole.USER:
-            role = human_prefix
-        elif message.role == PromptMessageRole.ASSISTANT:
-            role = ai_prefix
-        else:
-            continue
-
-        if isinstance(message.content, list):
-            content_parts = []
-            for content in message.content:
-                if isinstance(content, TextPromptMessageContent):
-                    content_parts.append(content.data)
-                elif isinstance(content, ImagePromptMessageContent):
-                    content_parts.append("[image]")
-
-            inner_msg = "\n".join(content_parts)
-            string_messages.append(f"{role}: {inner_msg}")
-        else:
-            string_messages.append(f"{role}: {message.content}")
-    return "\n".join(string_messages)
-
-
 def _handle_completion_template(
     *,
     template: LLMNodeCompletionModelPromptTemplate,
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index af3a4cdad3..3353a163ad 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -5,7 +5,6 @@ import uuid
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast
 
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import ImagePromptMessageContent
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -24,12 +23,17 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.enums import (
+    NodeType,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
 from core.workflow.file import File
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.base import variable_template_parser
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.llm import llm_utils
+from core.workflow.nodes.llm.protocols import PromptMessageMemory
 from core.workflow.runtime import VariablePool
 from core.workflow.variables.types import ArrayValidation, SegmentType
 from factories.variable_factory import build_segment_with_type
@@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
     _model_instance: ModelInstance
     _credentials_provider: "CredentialsProvider"
     _model_factory: "ModelFactory"
+    _memory: PromptMessageMemory | None
 
     def __init__(
         self,
@@ -108,6 +113,7 @@
         credentials_provider: "CredentialsProvider",
         model_factory: "ModelFactory",
         model_instance: ModelInstance,
+        memory: PromptMessageMemory | None = None,
     ) -> None:
         super().__init__(
             id=id,
@@ -118,6 +124,7 @@
         self._credentials_provider = credentials_provider
         self._model_factory = model_factory
         self._model_instance = model_instance
+        self._memory = memory
 
     @classmethod
     def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -163,13 +170,7 @@
             model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
         except ValueError as exc:
             raise ModelSchemaNotFoundError("Model schema not found") from exc
-        # fetch memory
-        memory = llm_utils.fetch_memory(
-            variable_pool=variable_pool,
-            app_id=self.app_id,
-            node_data_memory=node_data.memory,
-            model_instance=model_instance,
-        )
+        memory = self._memory
         if (
             set(model_schema.features or [])
             & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@@ -316,7 +317,7 @@
         query: str,
         variable_pool: VariablePool,
         model_instance: ModelInstance,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -404,7 +405,7 @@
         query: str,
         variable_pool: VariablePool,
         model_instance: ModelInstance,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -442,7 +443,7 @@
         query: str,
         variable_pool: VariablePool,
         model_instance: ModelInstance,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -467,7 +468,8 @@
             files=files,
             context="",
             memory_config=node_data.memory,
-            memory=memory,
+            # AdvancedPromptTransform is still typed against TokenBufferMemory.
+            memory=cast(Any, memory),
             model_instance=model_instance,
             image_detail_config=vision_detail,
         )
@@ -480,7 +482,7 @@
         query: str,
         variable_pool: VariablePool,
         model_instance: ModelInstance,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -712,7 +714,7 @@
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode(node_data.model.mode)
@@ -721,8 +723,8 @@
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
+            memory_str = llm_utils.fetch_memory_text(
+                memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
             )
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
@@ -739,7 +741,7 @@
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
@@ -748,8 +750,8 @@
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
+            memory_str = llm_utils.fetch_memory_text(
+                memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
             )
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 5d5edcc0f7..789ff605cc 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -3,7 +3,6 @@ import re
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -27,7 +26,7 @@ from core.workflow.nodes.llm import (
     llm_utils,
 )
 from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
-from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
 from .entities import QuestionClassifierNodeData
@@ -56,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     _credentials_provider: "CredentialsProvider"
     _model_factory: "ModelFactory"
     _model_instance: ModelInstance
+    _memory: PromptMessageMemory | None
 
     def __init__(
         self,
@@ -67,6 +67,7 @@
         credentials_provider: "CredentialsProvider",
         model_factory: "ModelFactory",
         model_instance: ModelInstance,
+        memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -81,6 +82,7 @@
         self._credentials_provider = credentials_provider
         self._model_factory = model_factory
         self._model_instance = model_instance
+        self._memory = memory
 
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
@@ -103,13 +105,7 @@
         variables = {"query": query}
         # fetch model instance
         model_instance = self._model_instance
-        # fetch memory
-        memory = llm_utils.fetch_memory(
-            variable_pool=variable_pool,
-            app_id=self.app_id,
-            node_data_memory=node_data.memory,
-            model_instance=model_instance,
-        )
+        memory = self._memory
         # fetch instruction
         node_data.instruction = node_data.instruction or ""
         node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@@ -327,7 +323,7 @@
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
-        memory: TokenBufferMemory | None,
+        memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
@@ -340,7 +336,8 @@
         input_text = query
         memory_str = ""
         if memory:
-            memory_str = memory.get_history_prompt_text(
+            memory_str = llm_utils.fetch_memory_text(
+                memory=memory,
                 max_token_limit=max_token_limit,
                 message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
             )
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index e791f12393..773074e92d 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.model_manager import ModelInstance
-from core.model_runtime.entities import AssistantPromptMessage
+from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
 
 def get_mocked_fetch_memory(memory_text: str):
     class MemoryMock:
-        def get_history_prompt_text(
+        def get_history_prompt_messages(
             self,
-            human_prefix: str = "Human",
-            ai_prefix: str = "Assistant",
             max_token_limit: int = 2000,
             message_limit: int | None = None,
         ):
-            return memory_text
+            return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
 
     return MagicMock(return_value=MemoryMock())
 
 
-def init_parameter_extractor_node(config: dict):
+def init_parameter_extractor_node(config: dict, memory=None):
     graph_config = {
         "edges": [
             {
@@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
         credentials_provider=MagicMock(spec=CredentialsProvider),
         model_factory=MagicMock(spec=ModelFactory),
         model_instance=MagicMock(spec=ModelInstance),
+        memory=memory,
     )
     return node
 
@@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
     assert result["location"] == "kawaii"
 
 
-def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
+def test_chat_parameter_extractor_with_memory(setup_model_mock):
     """
     Test chat parameter extractor with memory.
     """
@@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
                 "memory": {"window": {"enabled": True, "size": 50}},
             },
         },
+        memory=get_mocked_fetch_memory("customized memory")(),
    )
 
     node._model_instance = get_mocked_fetch_model_instance(
@@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
     )()
-    # Test the mock before running the actual test
-    monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
 
     db.session.close = MagicMock()
     result = node._run()
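
Note for reviewers: the PromptMessageMemory protocol that the nodes now depend
on lives in api/core/workflow/nodes/llm/protocols.py, which this diff never
shows. Judging from the call sites above (fetch_memory_text and the MemoryMock
in the test), its shape is presumably close to the sketch below; treat the
exact definition as an inference, not a quote from the repository:

    # Hypothetical reconstruction of the protocol, for illustration only.
    from collections.abc import Sequence
    from typing import Protocol

    from core.model_runtime.entities.message_entities import PromptMessage


    class PromptMessageMemory(Protocol):
        def get_history_prompt_messages(
            self,
            max_token_limit: int = 2000,
            message_limit: int | None = None,
        ) -> Sequence[PromptMessage]:
            ...

TokenBufferMemory already exposes get_history_prompt_messages (the old code in
node.py called it directly), so the factory can hand it to the nodes unchanged.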
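
Under the same assumption, any object with a matching method can be injected,
no database row required. A minimal stand-in (StaticMemory is a made-up name
for illustration; it is not part of the codebase):

    from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage


    class StaticMemory:
        """Fixed chat history that satisfies PromptMessageMemory structurally."""

        def __init__(self, turns: list[tuple[str, str]]):
            # Each tuple is one (user, assistant) exchange.
            self._turns = turns

        def get_history_prompt_messages(
            self,
            max_token_limit: int = 2000,
            message_limit: int | None = None,
        ):
            # max_token_limit is ignored here; a real implementation would
            # trim the history to fit the model's context window.
            messages = []
            for user_text, assistant_text in self._turns:
                messages.append(UserPromptMessage(content=user_text))
                messages.append(AssistantPromptMessage(content=assistant_text))
            if message_limit is not None:
                messages = messages[-message_limit:]
            return messages

With the default prefixes, llm_utils.fetch_memory_text(memory=StaticMemory([("hi", "hello")]),
max_token_limit=2000) renders the history as "Human: hi" and "Assistant: hello"
on separate lines, matching convert_history_messages_to_text above.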