mirror of https://github.com/langgenius/dify.git
feat: add independent memory
This commit is contained in:
parent
48be8fb6cc
commit
a9bae7aafd
|
|
@ -1,18 +1,64 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
class BaseMemory:
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get the history prompt messages
|
||||
Get history prompt messages.
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_text(self) -> str:
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the history prompt text
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
|
|
|||
|
|
@ -1,26 +1,28 @@
|
|||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.memory.base_memory import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun
|
||||
|
||||
|
||||
class ModelContextMemory:
|
||||
class ModelContextMemory(BaseMemory):
|
||||
def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
|
||||
self.conversation = conversation
|
||||
self.node_id = node_id
|
||||
|
|
@ -34,8 +36,104 @@ class ModelContextMemory:
|
|||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
thread_messages = list(reversed(self._fetch_thread_messages(message_limit)))
|
||||
if not thread_messages:
|
||||
return []
|
||||
# Get all required workflow_run_ids
|
||||
workflow_run_ids = [msg.workflow_run_id for msg in thread_messages]
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
# Batch query all related WorkflowNodeExecution records
|
||||
node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.workflow_run_id.in_(workflow_run_ids),
|
||||
WorkflowNodeExecution.node_id == self.node_id,
|
||||
WorkflowNodeExecution.status.in_(
|
||||
[WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create mapping from workflow_run_id to node_execution
|
||||
node_execution_map = {ne.workflow_run_id: ne for ne in node_executions}
|
||||
|
||||
# Get the last node_execution
|
||||
last_node_execution = node_execution_map.get(thread_messages[-1].workflow_run_id)
|
||||
prompt_messages = self._get_prompt_messages_in_process_data(last_node_execution)
|
||||
|
||||
# Batch query all message-related files
|
||||
message_ids = [msg.id for msg in thread_messages]
|
||||
all_files = db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_ids)).all()
|
||||
|
||||
# Create mapping from message_id to files
|
||||
files_map = {}
|
||||
for file in all_files:
|
||||
if file.message_id not in files_map:
|
||||
files_map[file.message_id] = []
|
||||
files_map[file.message_id].append(file)
|
||||
|
||||
for message in thread_messages:
|
||||
files = files_map.get(message.id, [])
|
||||
node_execution = node_execution_map.get(message.workflow_run_id)
|
||||
if node_execution and files:
|
||||
file_objs, detail = self._handle_file(message, files)
|
||||
if file_objs:
|
||||
outputs = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
|
||||
if not outputs:
|
||||
continue
|
||||
if outputs not in [prompt.content for prompt in prompt_messages]:
|
||||
continue
|
||||
outputs_index = [prompt.content for prompt in prompt_messages].index(outputs)
|
||||
prompt_index = outputs_index - 1
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
content = cast(str, prompt_messages[prompt_index].content)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=content))
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_messages[prompt_index].content = prompt_message_contents
|
||||
return prompt_messages
|
||||
|
||||
def _get_prompt_messages_in_process_data(
|
||||
self,
|
||||
node_execution: WorkflowNodeExecution,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Get prompt messages in process data.
|
||||
:param node_execution: node execution
|
||||
:return: prompt messages
|
||||
"""
|
||||
prompt_messages = []
|
||||
if not node_execution.process_data:
|
||||
return []
|
||||
|
||||
try:
|
||||
process_data = json.loads(node_execution.process_data)
|
||||
if process_data.get("memory_type", "") != LLMMemoryType.INDEPENDENT:
|
||||
return []
|
||||
prompts = process_data.get("prompts", [])
|
||||
for prompt in prompts:
|
||||
prompt_content = prompt.get("text", "")
|
||||
if prompt.get("role", "") == "user":
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_content))
|
||||
elif prompt.get("role", "") == "assistant":
|
||||
prompt_messages.append(AssistantPromptMessage(content=prompt_content))
|
||||
output = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
|
||||
prompt_messages.append(AssistantPromptMessage(content=output))
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return prompt_messages
|
||||
|
||||
def _fetch_thread_messages(self, message_limit: int | None = None) -> list[Message]:
|
||||
"""
|
||||
Fetch thread messages.
|
||||
:param message_limit: message limit
|
||||
:return: thread messages
|
||||
"""
|
||||
query = (
|
||||
db.session.query(
|
||||
Message.id,
|
||||
|
|
@ -59,147 +157,44 @@ class ModelContextMemory:
|
|||
|
||||
messages = query.limit(message_limit).all()
|
||||
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
# fetch the thread messages
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
if len(thread_messages) == 0:
|
||||
if not thread_messages:
|
||||
return []
|
||||
last_thread_message = list(reversed(thread_messages))[0]
|
||||
last_node_execution = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.workflow_run_id == last_thread_message.workflow_run_id,
|
||||
WorkflowNodeExecution.node_id == self.node_id,
|
||||
WorkflowNodeExecution.status.in_(
|
||||
[WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
|
||||
),
|
||||
)
|
||||
.order_by(WorkflowNodeExecution.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
return thread_messages
|
||||
|
||||
# files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
# if files:
|
||||
# file_extra_config = None
|
||||
# if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
# file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
# else:
|
||||
# if message.workflow_run_id:
|
||||
# workflow_run = (
|
||||
# db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||
# )
|
||||
def _handle_file(self, message: Message, files: list[MessageFile]):
|
||||
"""
|
||||
Handle file for memory.
|
||||
:param message: message
|
||||
:param files: files
|
||||
:return: file objects and detail
|
||||
"""
|
||||
file_extra_config = None
|
||||
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
else:
|
||||
if message.workflow_run_id:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||
|
||||
# if workflow_run and workflow_run.workflow:
|
||||
# file_extra_config = FileUploadConfigManager.convert(
|
||||
# workflow_run.workflow.features_dict, is_vision=False
|
||||
# )
|
||||
|
||||
# detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
# if file_extra_config and app_record:
|
||||
# file_objs = file_factory.build_from_message_files(
|
||||
# message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
# )
|
||||
# if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
# detail = file_extra_config.image_config.detail
|
||||
# else:
|
||||
# file_objs = []
|
||||
|
||||
# if not file_objs:
|
||||
# prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
# else:
|
||||
# prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
# prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
# for file in file_objs:
|
||||
# prompt_message = file_manager.to_prompt_message_content(
|
||||
# file,
|
||||
# image_detail_config=detail,
|
||||
# )
|
||||
# prompt_message_contents.append(prompt_message)
|
||||
|
||||
# prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
|
||||
# else:
|
||||
# prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
if last_node_execution and last_node_execution.process_data:
|
||||
try:
|
||||
process_data = json.loads(last_node_execution.process_data)
|
||||
if process_data.get("memory_type", "") == LLMMemoryType.INDEPENDENT:
|
||||
for prompt in process_data.get("prompts", []):
|
||||
if prompt.get("role") == "user":
|
||||
prompt_messages.append(
|
||||
UserPromptMessage(
|
||||
content=prompt.get("content"),
|
||||
)
|
||||
)
|
||||
elif prompt.get("role") == "assistant":
|
||||
prompt_messages.append(
|
||||
AssistantPromptMessage(
|
||||
content=prompt.get("content"),
|
||||
)
|
||||
)
|
||||
output = (
|
||||
json.loads(last_node_execution.outputs).get("text", "") if last_node_execution.outputs else ""
|
||||
if workflow_run and workflow_run.workflow:
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
workflow_run.workflow.features_dict, is_vision=False
|
||||
)
|
||||
prompt_messages.append(AssistantPromptMessage(content=output))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
app_record = self.conversation.app
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
pruned_memory.append(prompt_messages.pop(0))
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
return file_objs, detail
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ from typing import Optional
|
|||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.memory.base_memory import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
|
@ -20,7 +20,7 @@ from models.model import AppMode, Conversation, Message, MessageFile
|
|||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
class TokenBufferMemory(BaseMemory):
|
||||
def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
|
|
@ -129,44 +129,3 @@ class TokenBufferMemory:
|
|||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
|
|
|||
|
|
@ -606,6 +606,17 @@ class WorkflowNodeExecution(Base):
|
|||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
db.Index(
|
||||
"workflow_node_execution_run_node_status_idx",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
"status",
|
||||
),
|
||||
db.Index(
|
||||
"workflow_node_execution_run_status_idx",
|
||||
"workflow_run_id",
|
||||
"status",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@ const MemoryConfig: FC<Props> = ({
|
|||
})
|
||||
onChange(newPayload)
|
||||
}}
|
||||
defaultValue={payload.type}
|
||||
/>
|
||||
</div>
|
||||
{canSetRoleName && (
|
||||
|
|
|
|||
Loading…
Reference in New Issue