feat: refactor: refactor from ChatflowHistoryService and ChatflowMemoryService

This commit is contained in:
Stream 2025-08-22 15:33:45 +08:00
parent f72ed4898c
commit 4d2fc66a8d
No known key found for this signature in database
GPG Key ID: 033728094B100D70
7 changed files with 80 additions and 147 deletions

View File

@ -1,9 +1,10 @@
import logging
from collections.abc import Mapping, MutableMapping
from typing import Any, Optional, cast, override
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
@ -417,6 +418,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
memory_block_specs=memory_block_specs,
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft
)

View File

@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Optional, cast, Mapping
from typing import Optional, cast
import json_repair
@ -14,9 +14,10 @@ from core.llm_generator.prompts import (
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
MEMORY_UPDATE_PROMPT,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, MEMORY_UPDATE_PROMPT,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.memory.entities import MemoryBlock, MemoryBlockSpec
from core.model_manager import ModelManager
@ -577,7 +578,7 @@ class LLMGenerator:
@staticmethod
def update_memory_block(
tenant_id: str,
visible_history: Mapping[str, str],
visible_history: Sequence[tuple[str, str]],
memory_block: MemoryBlock,
memory_spec: MemoryBlockSpec
) -> str:
@ -588,7 +589,7 @@ class LLMGenerator:
model_type=ModelType.LLM,
)
formatted_history = ""
for sender, message in visible_history.items():
for sender, message in visible_history:
formatted_history += f"{sender}: {message}\n"
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
inputs={

View File

@ -436,4 +436,4 @@ Update instruction:
{{instruction}}
Please output only the updated memory content, no other text like greeting:
""" # noqa: E501
"""

View File

@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel, Field

View File

@ -1,16 +1,14 @@
import json
import time
from collections.abc import Sequence
from typing import Literal, Optional, overload, MutableMapping
from collections.abc import MutableMapping, Sequence
from typing import Literal, Optional, overload
from sqlalchemy import Row, Select, and_, func, select
from sqlalchemy.orm import Session
from core.memory.entities import ChatflowConversationMetadata
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
UserPromptMessage,
)
from extensions.ext_database import db
from models.chatflow_memory import ChatflowConversation, ChatflowMessage

View File

@ -7,6 +7,7 @@ from typing import Optional, cast
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.entities import (
MemoryBlock,
MemoryBlockSpec,
@ -16,7 +17,7 @@ from core.memory.entities import (
MemoryTerm,
)
from core.memory.errors import MemorySyncTimeoutError
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.entities.message_entities import PromptMessage
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
@ -102,9 +103,9 @@ class ChatflowMemoryService:
@staticmethod
def get_memories_by_specs(memory_block_specs: Sequence[MemoryBlockSpec],
tenant_id: str, app_id: str,
conversation_id: Optional[str] = None,
node_id: Optional[str] = None,
is_draft: bool = False) -> Sequence[MemoryBlock]:
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool) -> Sequence[MemoryBlock]:
return [ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, conversation_id, node_id, is_draft
) for spec in memory_block_specs]
@ -112,9 +113,9 @@ class ChatflowMemoryService:
@staticmethod
def get_memory_by_spec(spec: MemoryBlockSpec,
tenant_id: str, app_id: str,
conversation_id: Optional[str] = None,
node_id: Optional[str] = None,
is_draft: bool = False) -> MemoryBlock:
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool) -> MemoryBlock:
with (Session(bind=db.engine) as session):
if is_draft:
draft_var_service = WorkflowDraftVariableService(session)
@ -161,34 +162,6 @@ class ChatflowMemoryService:
node_id=node_id
)
@staticmethod
def get_app_memories_by_workflow(workflow, tenant_id: str,
conversation_id: Optional[str] = None) -> Sequence[MemoryBlock]:
app_memory_specs = [spec for spec in workflow.memory_blocks if spec.scope == MemoryScope.APP]
return ChatflowMemoryService.get_memories_by_specs(
memory_block_specs=app_memory_specs,
tenant_id=tenant_id,
app_id=workflow.app_id,
conversation_id=conversation_id
)
@staticmethod
def get_node_memories_by_workflow(workflow, node_id: str, tenant_id: str) -> Sequence[MemoryBlock]:
"""Get node-scoped memories based on workflow configuration"""
from core.memory.entities import MemoryScope
node_memory_specs = [
spec for spec in workflow.memory_blocks
if spec.scope == MemoryScope.NODE and spec.id == node_id
]
return ChatflowMemoryService.get_memories_by_specs(
memory_block_specs=node_memory_specs,
tenant_id=tenant_id,
app_id=workflow.app_id,
node_id=node_id
)
@staticmethod
def update_node_memory_if_needed(
tenant_id: str,
@ -199,28 +172,36 @@ class ChatflowMemoryService:
variable_pool: VariablePool,
is_draft: bool
) -> bool:
"""Update node-level memory after LLM execution"""
conversation_id_segment = variable_pool.get(('sys', 'conversation_id'))
if not conversation_id_segment:
return False
conversation_id = conversation_id_segment.value
if not ChatflowMemoryService._should_update_memory(
tenant_id, app_id, memory_block_spec, str(conversation_id), node_id
tenant_id=tenant_id,
app_id=app_id,
memory_block_spec=memory_block_spec,
conversation_id=conversation_id,
node_id=node_id
):
return False
if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
# Node-level sync: blocking execution
ChatflowMemoryService._update_node_memory_sync(
tenant_id, app_id, memory_block_spec, node_id,
str(conversation_id), variable_pool, is_draft
tenant_id=tenant_id,
app_id=app_id,
memory_block_spec=memory_block_spec,
node_id=node_id,
conversation_id=conversation_id,
variable_pool=variable_pool,
is_draft=is_draft
)
else:
# Node-level async: execute asynchronously
ChatflowMemoryService._update_node_memory_async(
tenant_id, app_id, memory_block_spec, node_id,
llm_output, str(conversation_id), variable_pool, is_draft
tenant_id=tenant_id,
app_id=app_id,
memory_block_spec=memory_block_spec,
node_id=node_id,
conversation_id=conversation_id,
variable_pool=variable_pool,
is_draft=is_draft
)
return True
@ -364,12 +345,14 @@ class ChatflowMemoryService:
# Node-level async update method
@staticmethod
def _update_node_memory_async(tenant_id: str, app_id: str,
memory_block_spec: MemoryBlockSpec,
node_id: str, llm_output: str,
conversation_id: str,
variable_pool: VariablePool,
is_draft: bool = False):
def _update_node_memory_async(
tenant_id: str,
app_id: str,
memory_block_spec: MemoryBlockSpec,
node_id: str,
conversation_id: str,
variable_pool: VariablePool,
is_draft: bool = False):
"""Asynchronously update node memory (submit task)"""
# Execute update asynchronously using thread
@ -380,7 +363,6 @@ class ChatflowMemoryService:
'tenant_id': tenant_id,
'app_id': app_id,
'node_id': node_id,
'llm_output': llm_output,
'variable_pool': variable_pool,
'is_draft': is_draft
},
@ -390,10 +372,15 @@ class ChatflowMemoryService:
# Return immediately without waiting
@staticmethod
def _perform_node_memory_update(*, memory_block_spec: MemoryBlockSpec,
tenant_id: str, app_id: str, node_id: str,
llm_output: str, variable_pool: VariablePool,
is_draft: bool = False):
def _perform_node_memory_update(
*,
memory_block_spec: MemoryBlockSpec,
tenant_id: str,
app_id: str,
node_id: str,
variable_pool: VariablePool,
is_draft: bool = False
):
ChatflowMemoryService._perform_memory_update(
tenant_id=tenant_id,
app_id=app_id,
@ -422,35 +409,36 @@ class ChatflowMemoryService:
)
@staticmethod
def _perform_memory_update(tenant_id: str, app_id: str,
memory_block_spec: MemoryBlockSpec,
conversation_id: str, variable_pool: VariablePool,
node_id: Optional[str] = None,
is_draft: bool = False):
"""Perform the actual memory update using LLM"""
def _perform_memory_update(
tenant_id: str, app_id: str,
memory_block_spec: MemoryBlockSpec,
conversation_id: str,
variable_pool: VariablePool,
node_id: Optional[str],
is_draft: bool):
history = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=node_id,
)
# Get current memory value
current_memory = ChatflowMemoryService._get_memory_from_chatflow_table(
memory_id=memory_block_spec.id,
memory_block = ChatflowMemoryService.get_memory_by_spec(
tenant_id=tenant_id,
spec=memory_block_spec,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id
node_id=node_id,
is_draft=is_draft
)
updated_value = LLMGenerator.update_memory_block(
tenant_id=tenant_id,
visible_history=ChatflowMemoryService._format_chat_history(history),
memory_block=memory_block,
memory_spec=memory_block_spec,
)
current_value = current_memory.value if current_memory else memory_block_spec.template
# Save updated memory
updated_memory = MemoryBlock(
id=current_memory.id if current_memory else "",
id=memory_block.id,
memory_id=memory_block_spec.id,
name=memory_block_spec.name,
value=updated_value,
@ -460,74 +448,17 @@ class ChatflowMemoryService:
conversation_id=conversation_id if memory_block_spec.term == MemoryTerm.SESSION else None,
node_id=node_id
)
ChatflowMemoryService.save_memory(updated_memory, tenant_id, variable_pool, is_draft)
# Not implemented yet: Send success event
# self._send_memory_update_event(memory_block_spec.id, "completed", updated_value)
@staticmethod
def _invoke_llm_for_memory_update(tenant_id: str,
memory_block_spec: MemoryBlockSpec,
prompt: str, current_value: str) -> Optional[str]:
"""Invoke LLM to update memory content
Args:
tenant_id: Tenant ID
memory_block_spec: Memory block specification
prompt: Update prompt
current_value: Current memory value (used for fallback on failure)
Returns:
Updated value, returns None if failed
"""
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.model_entities import ModelType
model_manager = ModelManager()
# Use model configuration defined in memory_block_spec, use default model if not specified
if hasattr(memory_block_spec, 'model') and memory_block_spec.model:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=memory_block_spec.model.get("provider", ""),
model=memory_block_spec.model.get("name", "")
)
model_parameters = memory_block_spec.model.get("completion_params", {})
else:
# Use default model
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_parameters = {"temperature": 0.7, "max_tokens": 1000}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content=prompt)],
model_parameters=model_parameters,
stream=False
)
)
return response.message.get_text_content()
except Exception as e:
logger.exception("Failed to update memory using LLM", exc_info=e)
# Not implemented yet: Send failure event
# ChatflowMemoryService._send_memory_update_event(memory_block_spec.id, "failed", current_value, str(e))
return None
def _send_memory_update_event(self, memory_id: str, status: str, value: str, error: str = ""):
"""Send memory update event
Note: Event system integration not implemented yet, this method is retained as a placeholder
"""
# Not implemented yet: Event system integration will be added in future versions
pass
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
result = []
for message in messages:
result.append((str(message.role.value), message.get_text_content()))
return result
# App-level sync batch update related methods
@staticmethod

View File

@ -756,6 +756,7 @@ def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: boo
memory_block_specs=memory_block_specs,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft,
)