mirror of https://github.com/langgenius/dify.git
refactor: ChatflowHistoryService and ChatflowMemoryService
This commit is contained in:
parent f72ed4898c
commit 4d2fc66a8d
@@ -1,9 +1,10 @@
 import logging
 from collections.abc import Mapping, MutableMapping
-from typing import Any, Optional, cast, override
+from typing import Any, Optional, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
+from typing_extensions import override
 
 from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
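Note: typing.override exists only on Python 3.12+, while typing_extensions backports the same decorator to older interpreters, which is what this import swap buys. A minimal sketch of the usage the import supports (toy classes, not the runner's):

from typing_extensions import override

class Base:
    def run(self) -> str:
        return "base"

class Child(Base):
    @override  # a type checker now errors if Base.run is renamed or removed
    def run(self) -> str:
        return "child"

assert Child().run() == "child"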
@@ -417,6 +418,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             memory_block_specs=memory_block_specs,
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
+            node_id=None,
             conversation_id=conversation_id,
             is_draft=is_draft
         )
@@ -2,7 +2,7 @@ import json
 import logging
 import re
 from collections.abc import Sequence
-from typing import Optional, cast, Mapping
+from typing import Optional, cast
 
 import json_repair
@@ -14,9 +14,10 @@ from core.llm_generator.prompts import (
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     LLM_MODIFY_CODE_SYSTEM,
     LLM_MODIFY_PROMPT_SYSTEM,
+    MEMORY_UPDATE_PROMPT,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
     SYSTEM_STRUCTURED_OUTPUT_GENERATE,
-    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, MEMORY_UPDATE_PROMPT,
+    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.memory.entities import MemoryBlock, MemoryBlockSpec
 from core.model_manager import ModelManager
@@ -577,7 +578,7 @@ class LLMGenerator:
     @staticmethod
     def update_memory_block(
         tenant_id: str,
-        visible_history: Mapping[str, str],
+        visible_history: Sequence[tuple[str, str]],
         memory_block: MemoryBlock,
         memory_spec: MemoryBlockSpec
     ) -> str:
@@ -588,7 +589,7 @@ class LLMGenerator:
             model_type=ModelType.LLM,
         )
         formatted_history = ""
-        for sender, message in visible_history.items():
+        for sender, message in visible_history:
             formatted_history += f"{sender}: {message}\n"
         formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
             inputs={
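The visible_history type change above is a correctness fix, not cosmetics: a Mapping[str, str] keyed by sender can hold only one entry per role, so a multi-turn transcript collapses to the last user and last assistant message. A sequence of (sender, message) tuples keeps every turn in order. A small illustration (toy data, not repository code):

# With a dict keyed by sender, earlier turns are silently overwritten.
turns = [("user", "hi"), ("assistant", "hello"), ("user", "please update my memory")]
as_mapping: dict[str, str] = {}
for sender, message in turns:
    as_mapping[sender] = message
assert as_mapping == {"user": "please update my memory", "assistant": "hello"}

# The refactored shape preserves all three turns, in order.
as_sequence: list[tuple[str, str]] = turns
formatted = "".join(f"{sender}: {message}\n" for sender, message in as_sequence)
assert formatted.startswith("user: hi\n")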
@@ -436,4 +436,4 @@ Update instruction:
 {{instruction}}
 
 Please output only the updated memory content, no other text like greeting:
-""" # noqa: E501
+"""
@@ -1,6 +1,6 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Optional
+from typing import Optional
 from uuid import uuid4
 
 from pydantic import BaseModel, Field
@@ -1,16 +1,14 @@
 import json
 import time
-from collections.abc import Sequence
-from typing import Literal, Optional, overload, MutableMapping
+from collections.abc import MutableMapping, Sequence
+from typing import Literal, Optional, overload
 
 from sqlalchemy import Row, Select, and_, func, select
 from sqlalchemy.orm import Session
 
 from core.memory.entities import ChatflowConversationMetadata
 from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
     PromptMessage,
-    UserPromptMessage,
 )
 from extensions.ext_database import db
 from models.chatflow_memory import ChatflowConversation, ChatflowMessage
@@ -7,6 +7,7 @@ from typing import Optional, cast
 from sqlalchemy import and_, select
 from sqlalchemy.orm import Session
 
+from core.llm_generator.llm_generator import LLMGenerator
 from core.memory.entities import (
     MemoryBlock,
     MemoryBlockSpec,
@@ -16,7 +17,7 @@ from core.memory.entities import (
     MemoryTerm,
 )
 from core.memory.errors import MemorySyncTimeoutError
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.entities.message_entities import PromptMessage
 from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
 from core.workflow.entities.variable_pool import VariablePool
 from extensions.ext_database import db
@@ -102,9 +103,9 @@ class ChatflowMemoryService:
     @staticmethod
     def get_memories_by_specs(memory_block_specs: Sequence[MemoryBlockSpec],
                               tenant_id: str, app_id: str,
-                              conversation_id: Optional[str] = None,
-                              node_id: Optional[str] = None,
-                              is_draft: bool = False) -> Sequence[MemoryBlock]:
+                              conversation_id: Optional[str],
+                              node_id: Optional[str],
+                              is_draft: bool) -> Sequence[MemoryBlock]:
         return [ChatflowMemoryService.get_memory_by_spec(
             spec, tenant_id, app_id, conversation_id, node_id, is_draft
         ) for spec in memory_block_specs]
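Dropping the parameter defaults here (and in get_memory_by_spec below) forces every call site to state the memory scope explicitly, which is why the runner hunks in this commit now pass node_id=None instead of leaning on an implicit default. A hedged sketch of the call-site contract (stub types, illustrative values):

from typing import Optional, Sequence

def get_memories_by_specs(memory_block_specs: Sequence[dict],
                          tenant_id: str, app_id: str,
                          conversation_id: Optional[str],
                          node_id: Optional[str],
                          is_draft: bool) -> Sequence[dict]:
    # Stub body; the real method delegates to get_memory_by_spec per spec.
    return []

# Previously the last three arguments could be omitted; now "app-scoped,
# published" must be spelled out at the caller:
get_memories_by_specs([], tenant_id="t", app_id="a",
                      conversation_id=None, node_id=None, is_draft=False)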
@@ -112,9 +113,9 @@ class ChatflowMemoryService:
     @staticmethod
     def get_memory_by_spec(spec: MemoryBlockSpec,
                            tenant_id: str, app_id: str,
-                           conversation_id: Optional[str] = None,
-                           node_id: Optional[str] = None,
-                           is_draft: bool = False) -> MemoryBlock:
+                           conversation_id: Optional[str],
+                           node_id: Optional[str],
+                           is_draft: bool) -> MemoryBlock:
         with (Session(bind=db.engine) as session):
             if is_draft:
                 draft_var_service = WorkflowDraftVariableService(session)
@@ -161,34 +162,6 @@ class ChatflowMemoryService:
             node_id=node_id
         )
 
-    @staticmethod
-    def get_app_memories_by_workflow(workflow, tenant_id: str,
-                                     conversation_id: Optional[str] = None) -> Sequence[MemoryBlock]:
-
-        app_memory_specs = [spec for spec in workflow.memory_blocks if spec.scope == MemoryScope.APP]
-        return ChatflowMemoryService.get_memories_by_specs(
-            memory_block_specs=app_memory_specs,
-            tenant_id=tenant_id,
-            app_id=workflow.app_id,
-            conversation_id=conversation_id
-        )
-
-    @staticmethod
-    def get_node_memories_by_workflow(workflow, node_id: str, tenant_id: str) -> Sequence[MemoryBlock]:
-        """Get node-scoped memories based on workflow configuration"""
-        from core.memory.entities import MemoryScope
-
-        node_memory_specs = [
-            spec for spec in workflow.memory_blocks
-            if spec.scope == MemoryScope.NODE and spec.id == node_id
-        ]
-        return ChatflowMemoryService.get_memories_by_specs(
-            memory_block_specs=node_memory_specs,
-            tenant_id=tenant_id,
-            app_id=workflow.app_id,
-            node_id=node_id
-        )
-
     @staticmethod
     def update_node_memory_if_needed(
         tenant_id: str,
@@ -199,28 +172,36 @@ class ChatflowMemoryService:
         variable_pool: VariablePool,
         is_draft: bool
     ) -> bool:
         """Update node-level memory after LLM execution"""
         conversation_id_segment = variable_pool.get(('sys', 'conversation_id'))
         if not conversation_id_segment:
             return False
         conversation_id = conversation_id_segment.value
 
         if not ChatflowMemoryService._should_update_memory(
-            tenant_id, app_id, memory_block_spec, str(conversation_id), node_id
+            tenant_id=tenant_id,
+            app_id=app_id,
+            memory_block_spec=memory_block_spec,
+            conversation_id=conversation_id,
+            node_id=node_id
         ):
             return False
 
         if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
             # Node-level sync: blocking execution
             ChatflowMemoryService._update_node_memory_sync(
-                tenant_id, app_id, memory_block_spec, node_id,
-                str(conversation_id), variable_pool, is_draft
+                tenant_id=tenant_id,
+                app_id=app_id,
+                memory_block_spec=memory_block_spec,
+                node_id=node_id,
+                conversation_id=conversation_id,
+                variable_pool=variable_pool,
+                is_draft=is_draft
             )
         else:
             # Node-level async: execute asynchronously
             ChatflowMemoryService._update_node_memory_async(
-                tenant_id, app_id, memory_block_spec, node_id,
-                llm_output, str(conversation_id), variable_pool, is_draft
+                tenant_id=tenant_id,
+                app_id=app_id,
+                memory_block_spec=memory_block_spec,
+                node_id=node_id,
+                conversation_id=conversation_id,
+                variable_pool=variable_pool,
+                is_draft=is_draft
            )
         return True
 
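Moving these internal calls to keyword arguments is what makes removing llm_output safe: in a positional call, deleting one parameter shifts every later argument into the wrong slot, often without an immediate error. A toy illustration (hypothetical functions, not the service API):

def update_positional(tenant_id, app_id, llm_output, conversation_id):
    return conversation_id

def update_keyword(*, tenant_id, app_id, conversation_id):
    return conversation_id

# Positional: if llm_output were removed from the signature, "conv-1" would
# silently bind to whichever parameter slides into its old position.
assert update_positional("t", "a", "output", "conv-1") == "conv-1"

# Keyword: the call survives the parameter removal unchanged and cannot misbind.
assert update_keyword(tenant_id="t", app_id="a", conversation_id="conv-1") == "conv-1"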
@@ -364,12 +345,14 @@ class ChatflowMemoryService:
 
     # Node-level async update method
     @staticmethod
-    def _update_node_memory_async(tenant_id: str, app_id: str,
-                                  memory_block_spec: MemoryBlockSpec,
-                                  node_id: str, llm_output: str,
-                                  conversation_id: str,
-                                  variable_pool: VariablePool,
-                                  is_draft: bool = False):
+    def _update_node_memory_async(
+        tenant_id: str,
+        app_id: str,
+        memory_block_spec: MemoryBlockSpec,
+        node_id: str,
+        conversation_id: str,
+        variable_pool: VariablePool,
+        is_draft: bool = False):
         """Asynchronously update node memory (submit task)"""
 
         # Execute update asynchronously using thread
@@ -380,7 +363,6 @@
                 'tenant_id': tenant_id,
                 'app_id': app_id,
                 'node_id': node_id,
-                'llm_output': llm_output,
                 'variable_pool': variable_pool,
                 'is_draft': is_draft
             },
@@ -390,10 +372,15 @@
         # Return immediately without waiting
 
     @staticmethod
-    def _perform_node_memory_update(*, memory_block_spec: MemoryBlockSpec,
-                                    tenant_id: str, app_id: str, node_id: str,
-                                    llm_output: str, variable_pool: VariablePool,
-                                    is_draft: bool = False):
+    def _perform_node_memory_update(
+        *,
+        memory_block_spec: MemoryBlockSpec,
+        tenant_id: str,
+        app_id: str,
+        node_id: str,
+        variable_pool: VariablePool,
+        is_draft: bool = False
+    ):
         ChatflowMemoryService._perform_memory_update(
             tenant_id=tenant_id,
             app_id=app_id,
@@ -422,35 +409,36 @@
         )
 
     @staticmethod
-    def _perform_memory_update(tenant_id: str, app_id: str,
-                               memory_block_spec: MemoryBlockSpec,
-                               conversation_id: str, variable_pool: VariablePool,
-                               node_id: Optional[str] = None,
-                               is_draft: bool = False):
-        """Perform the actual memory update using LLM"""
+    def _perform_memory_update(
+        tenant_id: str, app_id: str,
+        memory_block_spec: MemoryBlockSpec,
+        conversation_id: str,
+        variable_pool: VariablePool,
+        node_id: Optional[str],
+        is_draft: bool):
         history = ChatflowHistoryService.get_visible_chat_history(
             conversation_id=conversation_id,
             app_id=app_id,
             tenant_id=tenant_id,
             node_id=node_id,
         )
 
-        # Get current memory value
-        current_memory = ChatflowMemoryService._get_memory_from_chatflow_table(
-            memory_id=memory_block_spec.id,
+        memory_block = ChatflowMemoryService.get_memory_by_spec(
             tenant_id=tenant_id,
+            spec=memory_block_spec,
             app_id=app_id,
             conversation_id=conversation_id,
-            node_id=node_id
+            node_id=node_id,
+            is_draft=is_draft
         )
+        updated_value = LLMGenerator.update_memory_block(
+            tenant_id=tenant_id,
+            visible_history=ChatflowMemoryService._format_chat_history(history),
+            memory_block=memory_block,
+            memory_spec=memory_block_spec,
+        )
 
-        current_value = current_memory.value if current_memory else memory_block_spec.template
-
-
         # Save updated memory
         updated_memory = MemoryBlock(
-            id=current_memory.id if current_memory else "",
+            id=memory_block.id,
             memory_id=memory_block_spec.id,
             name=memory_block_spec.name,
             value=updated_value,
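After this rewrite, _perform_memory_update no longer tracks current_memory/current_value by hand: it reads the block through get_memory_by_spec and hands prompting to LLMGenerator.update_memory_block. A hedged sketch of the resulting data flow (stub functions standing in for the real services):

def get_visible_chat_history(conversation_id, app_id, tenant_id, node_id):
    return [("user", "remember that I prefer metric units")]

def get_memory_by_spec(tenant_id, spec, app_id, conversation_id, node_id, is_draft):
    return {"id": "blk-1", "value": "Preferences: (empty)"}

def update_memory_block(tenant_id, visible_history, memory_block, memory_spec):
    return memory_block["value"] + " + metric units"

def perform_memory_update(tenant_id, app_id, spec, conversation_id, node_id, is_draft):
    history = get_visible_chat_history(conversation_id, app_id, tenant_id, node_id)
    block = get_memory_by_spec(tenant_id, spec, app_id, conversation_id, node_id, is_draft)
    # One call owns model selection and prompting; the service only moves data.
    return update_memory_block(tenant_id, history, block, spec)

assert perform_memory_update("t", "a", {}, "c", None, False).endswith("metric units")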
@@ -460,74 +448,17 @@
             conversation_id=conversation_id if memory_block_spec.term == MemoryTerm.SESSION else None,
             node_id=node_id
         )
 
         ChatflowMemoryService.save_memory(updated_memory, tenant_id, variable_pool, is_draft)
 
         # Not implemented yet: Send success event
         # self._send_memory_update_event(memory_block_spec.id, "completed", updated_value)
 
-    @staticmethod
-    def _invoke_llm_for_memory_update(tenant_id: str,
-                                      memory_block_spec: MemoryBlockSpec,
-                                      prompt: str, current_value: str) -> Optional[str]:
-        """Invoke LLM to update memory content
-
-        Args:
-            tenant_id: Tenant ID
-            memory_block_spec: Memory block specification
-            prompt: Update prompt
-            current_value: Current memory value (used for fallback on failure)
-
-        Returns:
-            Updated value, returns None if failed
-        """
-        from core.model_manager import ModelManager
-        from core.model_runtime.entities.llm_entities import LLMResult
-        from core.model_runtime.entities.model_entities import ModelType
-
-        model_manager = ModelManager()
-
-        # Use model configuration defined in memory_block_spec, use default model if not specified
-        if hasattr(memory_block_spec, 'model') and memory_block_spec.model:
-            model_instance = model_manager.get_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM,
-                provider=memory_block_spec.model.get("provider", ""),
-                model=memory_block_spec.model.get("name", "")
-            )
-            model_parameters = memory_block_spec.model.get("completion_params", {})
-        else:
-            # Use default model
-            model_instance = model_manager.get_default_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM
-            )
-            model_parameters = {"temperature": 0.7, "max_tokens": 1000}
-
-        try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=[UserPromptMessage(content=prompt)],
-                    model_parameters=model_parameters,
-                    stream=False
-                )
-            )
-            return response.message.get_text_content()
-        except Exception as e:
-            logger.exception("Failed to update memory using LLM", exc_info=e)
-            # Not implemented yet: Send failure event
-            # ChatflowMemoryService._send_memory_update_event(memory_block_spec.id, "failed", current_value, str(e))
-            return None
-
-
-    def _send_memory_update_event(self, memory_id: str, status: str, value: str, error: str = ""):
-        """Send memory update event
-
-        Note: Event system integration not implemented yet, this method is retained as a placeholder
-        """
-        # Not implemented yet: Event system integration will be added in future versions
-        pass
+    def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
+        result = []
+        for message in messages:
+            result.append((str(message.role.value), message.get_text_content()))
+        return result
 
     # App-level sync batch update related methods
     @staticmethod
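The deleted _invoke_llm_for_memory_update duplicated model selection, invocation, and error handling that now live behind LLMGenerator.update_memory_block; all the service keeps is the _format_chat_history adapter added above. A minimal sketch of that adapter under assumed message types (a stand-in mirroring PromptMessage's role/content, not the real class):

from dataclasses import dataclass
from enum import Enum
from typing import Sequence

class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class FakePromptMessage:
    # Illustrative stand-in for core.model_runtime's PromptMessage.
    role: Role
    content: str

    def get_text_content(self) -> str:
        return self.content

def format_chat_history(messages: Sequence[FakePromptMessage]) -> Sequence[tuple[str, str]]:
    # Same shape as the new ChatflowMemoryService._format_chat_history:
    # the role's enum value becomes the sender label, the text the message.
    return [(str(m.role.value), m.get_text_content()) for m in messages]

history = [FakePromptMessage(Role.USER, "My name is Ada."),
           FakePromptMessage(Role.ASSISTANT, "Nice to meet you, Ada.")]
assert format_chat_history(history) == [("user", "My name is Ada."),
                                        ("assistant", "Nice to meet you, Ada.")]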
@@ -756,6 +756,7 @@ def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: boo
         memory_block_specs=memory_block_specs,
         tenant_id=workflow.tenant_id,
         app_id=workflow.app_id,
+        node_id=None,
         conversation_id=conversation_id,
         is_draft=is_draft,
     )