From a8c2a300f643004a6b820984f794fa6739e700f8 Mon Sep 17 00:00:00 2001 From: Stream Date: Mon, 22 Sep 2025 17:14:07 +0800 Subject: [PATCH] refactor: make memories API return MemoryBlock --- .../service_api/app/chatflow_memory.py | 4 +-- api/controllers/web/chatflow_memory.py | 4 +-- api/services/chatflow_memory_service.py | 35 ++++++++++++++++--- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/api/controllers/service_api/app/chatflow_memory.py b/api/controllers/service_api/app/chatflow_memory.py index 20ffed672e..c321814392 100644 --- a/api/controllers/service_api/app/chatflow_memory.py +++ b/api/controllers/service_api/app/chatflow_memory.py @@ -24,8 +24,8 @@ class MemoryListApi(Resource): if conversation_id: result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)] if memory_id: - result = [it for it in result if it.memory_id == memory_id] - return [it for it in result if it.end_user_visible] + result = [it for it in result if it.spec.id == memory_id] + return [it for it in result if it.spec.end_user_visible] class MemoryEditApi(Resource): diff --git a/api/controllers/web/chatflow_memory.py b/api/controllers/web/chatflow_memory.py index d0952190cb..6883760552 100644 --- a/api/controllers/web/chatflow_memory.py +++ b/api/controllers/web/chatflow_memory.py @@ -23,8 +23,8 @@ class MemoryListApi(WebApiResource): if conversation_id: result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)] if memory_id: - result = [it for it in result if it.memory_id == memory_id] - return [it for it in result if it.end_user_visible] + result = [it for it in result if it.spec.id == memory_id] + return [it for it in result if it.spec.end_user_visible] class MemoryEditApi(WebApiResource): diff --git a/api/services/chatflow_memory_service.py b/api/services/chatflow_memory_service.py index 57a1d8df3f..ac2da9be8f 100644 --- a/api/services/chatflow_memory_service.py +++ b/api/services/chatflow_memory_service.py @@ -39,7 +39,7 @@ class ChatflowMemoryService: def get_persistent_memories( app: App, version: int | None = None - ) -> Sequence[MemoryBlockWithVisibility]: + ) -> Sequence[MemoryBlock]: if version is None: # If version not specified, get the latest version stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where( @@ -60,14 +60,14 @@ class ChatflowMemoryService: ) with Session(db.engine) as session: db_results = session.execute(stmt).all() - return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results]) + return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results]) @staticmethod def get_session_memories( app: App, conversation_id: str, version: int | None = None - ) -> Sequence[MemoryBlockWithVisibility]: + ) -> Sequence[MemoryBlock]: if version is None: # If version not specified, get the latest version stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where( @@ -88,7 +88,7 @@ class ChatflowMemoryService: ) with Session(db.engine) as session: db_results = session.execute(stmt).all() - return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results]) + return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results]) @staticmethod def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None: @@ -349,6 +349,33 @@ class ChatflowMemoryService: conversation_id=conversation_id ) + @staticmethod + def _convert_to_memory_blocks( + app: App, + raw_results: Sequence[ChatflowMemoryVariable] + ) -> Sequence[MemoryBlock]: + workflow = WorkflowService().get_published_workflow(app) + if not workflow: + return [] + results = [] + for chatflow_memory_variable in raw_results: + spec = next( + (spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id), + None + ) + if spec and chatflow_memory_variable.app_id: + results.append( + MemoryBlock( + spec=spec, + tenant_id=chatflow_memory_variable.tenant_id, + value=MemoryValueData.model_validate_json(chatflow_memory_variable.value).value, + app_id=chatflow_memory_variable.app_id, + conversation_id=chatflow_memory_variable.conversation_id, + node_id=chatflow_memory_variable.node_id + ) + ) + return results + @staticmethod def _with_visibility( app: App,