feat: add memory update check in AdvancedChatAppRunner

Stream 2025-08-21 14:24:17 +08:00
parent 7ffcf8dd6f
commit 635c4ed4ce
1 changed file with 77 additions and 1 deletion


@@ -5,11 +5,15 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@@ -67,6 +71,8 @@ from core.workflow.nodes.event import (
    RunStreamChunkEvent,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import Workflow, db
from services.chatflow_memory_service import ChatflowMemoryService
from . import llm_utils
from .entities import (
@@ -290,6 +296,11 @@ class LLMNode(BaseNode):
        if self._file_outputs is not None:
            outputs["files"] = ArrayFileSegment(value=self._file_outputs)

        try:
            self._handle_chatflow_memory(result_text, variable_pool)
        except Exception as e:
            logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -1078,6 +1089,71 @@ class LLMNode(BaseNode):
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled

    def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
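        """Persist the latest user/assistant exchange and refresh node-scoped memory blocks.

        No-op unless this node's memory is configured in "block" mode.
        """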
        if not self._node_data.memory or self._node_data.memory.mode != "block":
            return
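
        # Resolve the conversation ID and the user's query from the system variables.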
        conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
        if not conversation_id_segment:
            raise ValueError("Conversation ID not found in variable pool.")
        conversation_id = conversation_id_segment.text

        user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
        if not user_query_segment:
            raise ValueError("User query not found in variable pool.")
        user_query = user_query_segment.text
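
        # Imported locally, presumably to avoid a circular import at module load time.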
        from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
        from services.chatflow_history_service import ChatflowHistoryService
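
        # Record both sides of the exchange as node-level chat history.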
        ChatflowHistoryService.save_node_message(
            prompt_message=UserPromptMessage(content=user_query),
            node_id=self.node_id,
            conversation_id=conversation_id,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
        )
        ChatflowHistoryService.save_node_message(
            prompt_message=AssistantPromptMessage(content=llm_output),
            node_id=self.node_id,
            conversation_id=conversation_id,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
        )
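
        # Nothing further to do unless this node references memory blocks.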
        memory_config = self._node_data.memory
        if not memory_config:
            return
        block_ids = memory_config.block_id
        if not block_ids:
            return
        # FIXME: dirty workaround; looking the workflow up by app alone may resolve the wrong workflow version.
        with Session(db.engine) as session:
            stmt = select(Workflow).where(
                Workflow.tenant_id == self.tenant_id,
                Workflow.app_id == self.app_id,
            )
            workflow = session.scalars(stmt).first()
            if not workflow:
                raise ValueError("Workflow not found.")
            memory_blocks = workflow.memory_blocks
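
        # Trigger updates only for node-scoped memory blocks referenced by this node.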
        for block_id in block_ids:
            memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
            if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
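                # Debugger invocations operate on the draft workflow version.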
                is_draft = self.invoke_from == InvokeFrom.DEBUGGER
                ChatflowMemoryService.update_node_memory_if_needed(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    memory_block_spec=memory_block_spec,
                    node_id=self.node_id,
                    llm_output=llm_output,
                    variable_pool=variable_pool,
                    is_draft=is_draft,
                )

def _combine_message_content_with_role(
    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole