mirror of https://github.com/langgenius/dify.git
feat: add memory update check in AdvancedChatAppRunner
This commit is contained in:
parent
7ffcf8dd6f
commit
635c4ed4ce
|
|
@ -5,11 +5,15 @@ import logging
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.entities import MemoryScope
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
|
|
@ -67,6 +71,8 @@ from core.workflow.nodes.event import (
|
|||
RunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models import Workflow, db
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
|
|
@ -290,6 +296,11 @@ class LLMNode(BaseNode):
|
|||
if self._file_outputs is not None:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
try:
|
||||
self._handle_chatflow_memory(result_text, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
@ -1078,6 +1089,71 @@ class LLMNode(BaseNode):
|
|||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
|
||||
if not self._node_data.memory or self._node_data.memory.mode != "block":
|
||||
return
|
||||
|
||||
conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
|
||||
if not conversation_id_segment:
|
||||
raise ValueError("Conversation ID not found in variable pool.")
|
||||
conversation_id = conversation_id_segment.text
|
||||
|
||||
user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not user_query_segment:
|
||||
raise ValueError("User query not found in variable pool.")
|
||||
user_query = user_query_segment.text
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
|
||||
ChatflowHistoryService.save_node_message(
|
||||
prompt_message=(UserPromptMessage(content=user_query)),
|
||||
node_id=self.node_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id
|
||||
)
|
||||
ChatflowHistoryService.save_node_message(
|
||||
prompt_message=(AssistantPromptMessage(content=llm_output)),
|
||||
node_id=self.node_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id
|
||||
)
|
||||
|
||||
memory_config = self._node_data.memory
|
||||
if not memory_config:
|
||||
return
|
||||
block_ids = memory_config.block_id
|
||||
if not block_ids:
|
||||
return
|
||||
|
||||
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == self.tenant_id,
|
||||
Workflow.app_id == self.app_id
|
||||
)
|
||||
workflow = session.scalars(stmt).first()
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found.")
|
||||
memory_blocks = workflow.memory_blocks
|
||||
|
||||
for block_id in block_ids:
|
||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id),None)
|
||||
|
||||
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
|
||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||
ChatflowMemoryService.update_node_memory_if_needed(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
memory_block_spec=memory_block_spec,
|
||||
node_id=self.node_id,
|
||||
llm_output=llm_output,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
|
|
|
|||
Loading…
Reference in New Issue