diff --git a/api/controllers/service_api/app/chatflow_memory.py b/api/controllers/service_api/app/chatflow_memory.py index 5484e18f3b..20ffed672e 100644 --- a/api/controllers/service_api/app/chatflow_memory.py +++ b/api/controllers/service_api/app/chatflow_memory.py @@ -1,13 +1,9 @@ from flask_restx import Resource, reqparse -from sqlalchemy.orm import Session from controllers.service_api import api from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.memory.entities import MemoryBlock from core.workflow.entities.variable_pool import VariablePool -from libs.helper import uuid_value -from models import db -from models.chatflow_memory import ChatflowMemoryVariable from services.chatflow_memory_service import ChatflowMemoryService from services.workflow_service import WorkflowService @@ -31,6 +27,7 @@ class MemoryListApi(Resource): result = [it for it in result if it.memory_id == memory_id] return [it for it in result if it.end_user_visible] + class MemoryEditApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def put(self, app_model): @@ -44,6 +41,8 @@ class MemoryEditApi(Resource): update = args.get("update") conversation_id = args.get("conversation_id") node_id = args.get("node_id") + if not isinstance(update, str): + return {'error': 'Invalid update'}, 400 if not workflow: return {'error': 'Workflow not found'}, 404 memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None) @@ -63,5 +62,6 @@ class MemoryEditApi(Resource): ) return '', 204 + api.add_resource(MemoryListApi, '/memories') api.add_resource(MemoryEditApi, '/memory-edit') diff --git a/api/controllers/web/chatflow_memory.py b/api/controllers/web/chatflow_memory.py index c56f50dc35..d0952190cb 100644 --- a/api/controllers/web/chatflow_memory.py +++ b/api/controllers/web/chatflow_memory.py @@ -1,14 +1,9 @@ from flask_restx import reqparse -from sqlalchemy.orm.session import Session -from sympy import false from controllers.web import api from controllers.web.wraps import WebApiResource from core.memory.entities import MemoryBlock from core.workflow.entities.variable_pool import VariablePool -from libs.helper import uuid_value -from models import db -from models.chatflow_memory import ChatflowMemoryVariable from services.chatflow_memory_service import ChatflowMemoryService from services.workflow_service import WorkflowService @@ -31,6 +26,7 @@ class MemoryListApi(WebApiResource): result = [it for it in result if it.memory_id == memory_id] return [it for it in result if it.end_user_visible] + class MemoryEditApi(WebApiResource): def put(self, app_model): parser = reqparse.RequestParser() @@ -43,6 +39,8 @@ class MemoryEditApi(WebApiResource): update = args.get("update") conversation_id = args.get("conversation_id") node_id = args.get("node_id") + if not isinstance(update, str): + return {'error': 'Update must be a string'}, 400 if not workflow: return {'error': 'Workflow not found'}, 404 memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None) @@ -64,5 +62,6 @@ class MemoryEditApi(WebApiResource): ) return '', 204 + api.add_resource(MemoryListApi, '/memories') api.add_resource(MemoryEditApi, '/memory-edit') diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index ef9ce7f6ac..c20e4c645e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -28,7 +28,7 @@ from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent +from core.workflow.graph_events import GraphRunSucceededEvent from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry @@ -223,6 +223,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not assistant_message: logger.warning("Chatflow output does not contain 'answer'.") return + if not isinstance(assistant_message, str): + logger.warning("Chatflow output 'answer' is not a string.") + return try: self._sync_conversation_to_chatflow_tables(assistant_message) except Exception as e: diff --git a/api/core/memory/entities.py b/api/core/memory/entities.py index 974d6c8f29..c42a12c2f3 100644 --- a/api/core/memory/entities.py +++ b/api/core/memory/entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from typing import Optional from uuid import uuid4 @@ -7,23 +7,23 @@ from pydantic import BaseModel, Field from core.app.app_config.entities import ModelConfig -class MemoryScope(str, Enum): +class MemoryScope(StrEnum): """Memory scope determined by node_id field""" APP = "app" # node_id is None NODE = "node" # node_id is not None -class MemoryTerm(str, Enum): +class MemoryTerm(StrEnum): """Memory term determined by conversation_id field""" SESSION = "session" # conversation_id is not None PERSISTENT = "persistent" # conversation_id is None -class MemoryStrategy(str, Enum): +class MemoryStrategy(StrEnum): ON_TURNS = "on_turns" -class MemoryScheduleMode(str, Enum): +class MemoryScheduleMode(StrEnum): SYNC = "sync" ASYNC = "async" @@ -69,6 +69,7 @@ class MemoryBlock(BaseModel): conversation_id: Optional[str] = None node_id: Optional[str] = None + class MemoryBlockWithVisibility(BaseModel): id: str name: str diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index c06fe4d1dd..0a50cccbca 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -202,9 +202,10 @@ class ArrayFileSegment(ArraySegment): def text(self) -> str: return "" + class VersionedMemoryValue(BaseModel): current_value: str = None # type: ignore - versions: Mapping[str, str] = dict() + versions: Mapping[str, str] = {} model_config = ConfigDict(frozen=True) @@ -215,7 +216,7 @@ class VersionedMemoryValue(BaseModel): ) -> "VersionedMemoryValue": if version_name is None: version_name = str(len(self.versions) + 1) - if version_name in self.versions.keys(): + if version_name in self.versions: raise ValueError(f"Version '{version_name}' already exists.") self.current_value = new_value return VersionedMemoryValue( @@ -226,6 +227,7 @@ class VersionedMemoryValue(BaseModel): } ) + class VersionedMemorySegment(Segment): value_type: SegmentType = SegmentType.VERSIONED_MEMORY value: VersionedMemoryValue = None # type: ignore diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 95789a68d0..5cda5c35e5 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -22,7 +22,8 @@ from .segments import ( ObjectSegment, Segment, StringSegment, - get_segment_discriminator, VersionedMemorySegment, + VersionedMemorySegment, + get_segment_discriminator, ) from .types import SegmentType @@ -105,9 +106,11 @@ class BooleanVariable(BooleanSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + class VersionedMemoryVariable(VersionedMemorySegment, Variable): pass + class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): pass diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e550d30476..e1fa3ce627 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -83,7 +83,6 @@ class VariablePool(BaseModel): for memory_id, memory_value in self.memory_blocks.items(): self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value) - def add(self, selector: Sequence[str], value: Any, /): """ Add a variable to the variable pool. diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 8615d6739e..1c3a3ac264 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1235,7 +1235,7 @@ class LLMNode(Node): memory_blocks = workflow.memory_blocks for block_id in block_ids: - memory_block_spec = next((block for block in memory_blocks if block.id == block_id),None) + memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None) if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE: is_draft = (self.invoke_from == InvokeFrom.DEBUGGER) diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 35f83d5799..e510b27b42 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -20,7 +20,9 @@ from core.variables.segments import ( NoneSegment, ObjectSegment, Segment, - StringSegment, VersionedMemorySegment, VersionedMemoryValue, + StringSegment, + VersionedMemorySegment, + VersionedMemoryValue, ) from core.variables.types import SegmentType from core.variables.variables import ( @@ -38,7 +40,8 @@ from core.variables.variables import ( ObjectVariable, SecretVariable, StringVariable, - Variable, VersionedMemoryVariable, + Variable, + VersionedMemoryVariable, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, diff --git a/api/services/chatflow_memory_service.py b/api/services/chatflow_memory_service.py index 7b12c2266d..079b7c31f8 100644 --- a/api/services/chatflow_memory_service.py +++ b/api/services/chatflow_memory_service.py @@ -32,6 +32,7 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) + class ChatflowMemoryService: @staticmethod def get_persistent_memories( @@ -186,9 +187,9 @@ class ChatflowMemoryService: ChatflowMemoryVariable.memory_id == spec.id, ChatflowMemoryVariable.tenant_id == tenant_id, ChatflowMemoryVariable.app_id == app_id, - ChatflowMemoryVariable.node_id == \ + ChatflowMemoryVariable.node_id == (node_id if spec.scope == MemoryScope.NODE else None), - ChatflowMemoryVariable.conversation_id == \ + ChatflowMemoryVariable.conversation_id == (conversation_id if spec.term == MemoryTerm.SESSION else None), ) ).order_by(ChatflowMemoryVariable.version.desc()).limit(1) @@ -517,6 +518,7 @@ class ChatflowMemoryService: result.append((str(message.role.value), message.get_text_content())) return result + def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str: """Generate Redis lock key for memory sync updates diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 76b101d3b3..cc33042950 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1055,6 +1055,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia else: raise Exception("unreachable") + def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: bool) -> Mapping[str, str]: memory_blocks = {} memory_block_specs = workflow.memory_blocks