diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
index 970b0c4c3d..3a82f0a45e 100644
--- a/api/core/app/workflow/node_factory.py
+++ b/api/core/app/workflow/node_factory.py
@@ -16,6 +16,7 @@ from core.helper.ssrf_proxy import ssrf_proxy
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.memory import PromptMessageMemory
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -35,7 +36,6 @@ from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import Kno
 from core.workflow.nodes.llm.entities import ModelConfig
 from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from core.workflow.nodes.llm.node import LLMNode
-from core.workflow.nodes.llm.protocols import PromptMessageMemory
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
diff --git a/api/core/model_runtime/memory/__init__.py b/api/core/model_runtime/memory/__init__.py
new file mode 100644
index 0000000000..2d954486c3
--- /dev/null
+++ b/api/core/model_runtime/memory/__init__.py
@@ -0,0 +1,3 @@
+from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
+
+__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
diff --git a/api/core/model_runtime/memory/prompt_message_memory.py b/api/core/model_runtime/memory/prompt_message_memory.py
new file mode 100644
index 0000000000..4491ddfd05
--- /dev/null
+++ b/api/core/model_runtime/memory/prompt_message_memory.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Protocol
+
+from core.model_runtime.entities import PromptMessage
+
+DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
+
+
+class PromptMessageMemory(Protocol):
+    """Port for loading memory as prompt messages."""
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
+    ) -> Sequence[PromptMessage]:
+        """Return historical prompt messages constrained by token/message limits."""
+        ...
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 475a904d1c..c06db0dc16 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -37,6 +37,7 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from core.model_runtime.memory import PromptMessageMemory
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -62,7 +63,7 @@ from core.workflow.node_events import (
 from core.workflow.nodes.base.entities import VariableSelector
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
-from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import VariablePool
 from core.workflow.variables import (
     ArrayFileSegment,
diff --git a/api/core/workflow/nodes/llm/protocols.py b/api/core/workflow/nodes/llm/protocols.py
index 5bca04165a..8e0365299d 100644
--- a/api/core/workflow/nodes/llm/protocols.py
+++ b/api/core/workflow/nodes/llm/protocols.py
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
-from collections.abc import Sequence
 from typing import Any, Protocol
 
 from core.model_manager import ModelInstance
-from core.model_runtime.entities import PromptMessage
 
 
 class CredentialsProvider(Protocol):
@@ -21,13 +19,3 @@ class ModelFactory(Protocol):
     def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
         """Create a model instance that is ready for schema lookup and invocation."""
         ...
-
-
-class PromptMessageMemory(Protocol):
-    """Port for loading memory as prompt messages for LLM nodes."""
-
-    def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: int | None = None
-    ) -> Sequence[PromptMessage]:
-        """Return historical prompt messages constrained by token/message limits."""
-        ...
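The diff relocates the `PromptMessageMemory` protocol from the LLM node's local `protocols.py` into the shared `core.model_runtime.memory` package, so both `node_factory.py` and `LLMNode` can depend on the same structural type. The sketch below illustrates how any history provider satisfies the protocol purely by matching the `get_history_prompt_messages` signature, with no inheritance required. It is a minimal, self-contained illustration: `InMemoryHistory`, `build_prompt`, and the simplified `PromptMessage` stand-in are hypothetical names invented for this example, not part of the PR; in the real codebase the protocol and `PromptMessage` come from `core.model_runtime`.

```python
from collections.abc import Sequence
from typing import Protocol

DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000


class PromptMessage:
    """Simplified stand-in for core.model_runtime.entities.PromptMessage (illustration only)."""

    def __init__(self, content: str) -> None:
        self.content = content


class PromptMessageMemory(Protocol):
    """Mirrors the protocol introduced in core/model_runtime/memory/prompt_message_memory.py."""

    def get_history_prompt_messages(
        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
    ) -> Sequence[PromptMessage]: ...


class InMemoryHistory:
    """Hypothetical adapter: satisfies PromptMessageMemory structurally, without subclassing."""

    def __init__(self, messages: list[PromptMessage]) -> None:
        self._messages = messages

    def get_history_prompt_messages(
        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        # Only the message_limit is applied here; a real adapter (e.g. TokenBufferMemory)
        # would also trim history to fit within max_token_limit.
        if message_limit is not None:
            return self._messages[-message_limit:]
        return self._messages


def build_prompt(memory: PromptMessageMemory | None) -> list[PromptMessage]:
    # A caller such as LLMNode can depend on the protocol rather than a concrete memory class.
    return list(memory.get_history_prompt_messages(message_limit=10)) if memory else []


if __name__ == "__main__":
    mem = InMemoryHistory([PromptMessage("hi"), PromptMessage("hello")])
    print([m.content for m in build_prompt(mem)])
```

Because `Protocol` relies on structural typing, moving the definition out of `core.workflow.nodes.llm.protocols` does not change which classes conform to it; it only removes the workflow-to-model-runtime import direction that the old location forced on callers like `node_factory.py`.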