From 6eab6a675c41cf0e82038a17d2a6a679d7146c81 Mon Sep 17 00:00:00 2001 From: Stream Date: Tue, 23 Sep 2025 16:56:07 +0800 Subject: [PATCH] feat: add created_by to memory blocks --- .../service_api/app/chatflow_memory.py | 28 ++++++--- api/controllers/web/chatflow_memory.py | 28 ++++++--- api/core/app/apps/advanced_chat/app_runner.py | 14 ++++- api/core/memory/entities.py | 6 ++ api/core/workflow/nodes/llm/node.py | 16 +++-- api/models/chatflow_memory.py | 2 + api/services/chatflow_memory_service.py | 61 +++++++++++++++---- api/services/workflow_service.py | 18 ++++-- 8 files changed, 135 insertions(+), 38 deletions(-) diff --git a/api/controllers/service_api/app/chatflow_memory.py b/api/controllers/service_api/app/chatflow_memory.py index ff890eb326..9742193645 100644 --- a/api/controllers/service_api/app/chatflow_memory.py +++ b/api/controllers/service_api/app/chatflow_memory.py @@ -2,27 +2,40 @@ from flask_restx import Resource, reqparse from controllers.service_api import api from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.memory.entities import MemoryBlock +from core.memory.entities import MemoryBlock, MemoryCreatedBy from core.workflow.entities.variable_pool import VariablePool +from models import App, EndUser from services.chatflow_memory_service import ChatflowMemoryService from services.workflow_service import WorkflowService class MemoryListApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def get(self, app_model): + def get(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=False, type=str | None, default=None) parser.add_argument("memory_id", required=False, type=str | None, default=None) parser.add_argument("version", required=False, type=int | None, default=None) args = parser.parse_args() - conversation_id = args.get("conversation_id") + conversation_id: str | None = args.get("conversation_id") memory_id = args.get("memory_id") version = args.get("version") - result = ChatflowMemoryService.get_persistent_memories(app_model, version) + result = ChatflowMemoryService.get_persistent_memories( + app_model, + MemoryCreatedBy(end_user_id=end_user.id), + version + ) if conversation_id: - result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)] + result = [ + *result, + *ChatflowMemoryService.get_session_memories( + app_model, + MemoryCreatedBy(end_user_id=end_user.id), + conversation_id, + version + ) + ] if memory_id: result = [it for it in result if it.spec.id == memory_id] return [it for it in result if it.spec.end_user_visible] @@ -30,7 +43,7 @@ class MemoryListApi(Resource): class MemoryEditApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def put(self, app_model): + def put(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('id', type=str, required=True) parser.add_argument("conversation_id", type=str | None, required=False, default=None) @@ -56,7 +69,8 @@ class MemoryEditApi(Resource): conversation_id=conversation_id, node_id=node_id, app_id=app_model.id, - edited_by_user=True + edited_by_user=True, + created_by=MemoryCreatedBy(end_user_id=end_user.id), ), variable_pool=VariablePool(), is_draft=False diff --git a/api/controllers/web/chatflow_memory.py b/api/controllers/web/chatflow_memory.py index 97d6e28c98..d41925cc63 100644 --- a/api/controllers/web/chatflow_memory.py +++ b/api/controllers/web/chatflow_memory.py @@ -2,33 +2,46 @@ from flask_restx import reqparse from controllers.web import api from controllers.web.wraps import WebApiResource -from core.memory.entities import MemoryBlock +from core.memory.entities import MemoryBlock, MemoryCreatedBy from core.workflow.entities.variable_pool import VariablePool +from models import App, EndUser from services.chatflow_memory_service import ChatflowMemoryService from services.workflow_service import WorkflowService class MemoryListApi(WebApiResource): - def get(self, app_model): + def get(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=False, type=str | None, default=None) parser.add_argument("memory_id", required=False, type=str | None, default=None) parser.add_argument("version", required=False, type=int | None, default=None) args = parser.parse_args() - conversation_id = args.get("conversation_id") + conversation_id: str | None = args.get("conversation_id") memory_id = args.get("memory_id") version = args.get("version") - result = ChatflowMemoryService.get_persistent_memories(app_model, version) + result = ChatflowMemoryService.get_persistent_memories( + app_model, + MemoryCreatedBy(end_user_id=end_user.id), + version + ) if conversation_id: - result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)] + result = [ + *result, + *ChatflowMemoryService.get_session_memories( + app_model, + MemoryCreatedBy(end_user_id=end_user.id), + conversation_id, + version + ) + ] if memory_id: result = [it for it in result if it.spec.id == memory_id] return [it for it in result if it.spec.end_user_visible] class MemoryEditApi(WebApiResource): - def put(self, app_model): + def put(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('id', type=str, required=True) parser.add_argument("conversation_id", type=str | None, required=False, default=None) @@ -56,7 +69,8 @@ class MemoryEditApi(WebApiResource): conversation_id=conversation_id, node_id=node_id, app_id=app_model.id, - edited_by_user=True + edited_by_user=True, + created_by=MemoryCreatedBy(end_user_id=end_user.id) ), variable_pool=VariablePool(), is_draft=False diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c20e4c645e..009c878919 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -21,7 +21,7 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature -from core.memory.entities import MemoryScope +from core.memory.entities import MemoryCreatedBy, MemoryScope from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration @@ -443,7 +443,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_id=self._workflow.app_id, node_id=None, conversation_id=conversation_id, - is_draft=is_draft + is_draft=is_draft, + created_by=self._get_created_by(), ) # Build memory_id -> value mapping @@ -482,5 +483,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, conversation_id=self.conversation.id, variable_pool=variable_pool, - is_draft=is_draft + is_draft=is_draft, + created_by=self._get_created_by() ) + + def _get_created_by(self) -> MemoryCreatedBy: + if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: + return MemoryCreatedBy(account_id=self.application_generate_entity.user_id) + else: + return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id) diff --git a/api/core/memory/entities.py b/api/core/memory/entities.py index c80cd88a8c..8ab985ad45 100644 --- a/api/core/memory/entities.py +++ b/api/core/memory/entities.py @@ -49,6 +49,11 @@ class MemoryBlockSpec(BaseModel): end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users") +class MemoryCreatedBy(BaseModel): + end_user_id: str | None = None + account_id: str | None = None + + class MemoryBlock(BaseModel): """Runtime memory block instance @@ -69,6 +74,7 @@ class MemoryBlock(BaseModel): conversation_id: Optional[str] = None node_id: Optional[str] = None edited_by_user: bool = False + created_by: MemoryCreatedBy class MemoryValueData(BaseModel): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1c3a3ac264..f89ae6f6c9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -14,7 +14,7 @@ from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.memory.entities import MemoryScope +from core.memory.entities import MemoryCreatedBy, MemoryScope from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( @@ -74,7 +74,8 @@ from core.workflow.node_events import ( from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from models import Workflow, db +from models import UserFrom, Workflow +from models.engine import db from services.chatflow_memory_service import ChatflowMemoryService from . import llm_utils @@ -1242,13 +1243,20 @@ class LLMNode(Node): ChatflowMemoryService.update_node_memory_if_needed( tenant_id=self.tenant_id, app_id=self.app_id, - node_id=self.node_id, + node_id=self.id, conversation_id=conversation_id, memory_block_spec=memory_block_spec, variable_pool=variable_pool, - is_draft=is_draft + is_draft=is_draft, + created_by=self._get_user_from_context() ) + def _get_user_from_context(self) -> MemoryCreatedBy: + if self.user_from == UserFrom.ACCOUNT: + return MemoryCreatedBy(account_id=self.user_id) + else: + return MemoryCreatedBy(end_user_id=self.user_id) + def _combine_message_content_with_role( *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole diff --git a/api/models/chatflow_memory.py b/api/models/chatflow_memory.py index cde48c5860..773c69405d 100644 --- a/api/models/chatflow_memory.py +++ b/api/models/chatflow_memory.py @@ -26,6 +26,8 @@ class ChatflowMemoryVariable(Base): scope: Mapped[str] = mapped_column(sa.String(10), nullable=False) # 'app' or 'node' term: Mapped[str] = mapped_column(sa.String(20), nullable=False) # 'session' or 'persistent' version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1) + created_by_role: Mapped[str] = mapped_column(sa.String(20)) # 'end_user' or 'account` + created_by: Mapped[str] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( diff --git a/api/services/chatflow_memory_service.py b/api/services/chatflow_memory_service.py index 45b602ad39..f67b058015 100644 --- a/api/services/chatflow_memory_service.py +++ b/api/services/chatflow_memory_service.py @@ -11,6 +11,7 @@ from core.llm_generator.llm_generator import LLMGenerator from core.memory.entities import ( MemoryBlock, MemoryBlockSpec, + MemoryCreatedBy, MemoryScheduleMode, MemoryScope, MemoryTerm, @@ -23,7 +24,7 @@ from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool from extensions.ext_database import db from extensions.ext_redis import redis_client -from models import App +from models import App, CreatorUserRole from models.chatflow_memory import ChatflowMemoryVariable from models.workflow import Workflow, WorkflowDraftVariable from services.chatflow_history_service import ChatflowHistoryService @@ -37,15 +38,24 @@ class ChatflowMemoryService: @staticmethod def get_persistent_memories( app: App, + created_by: MemoryCreatedBy, version: int | None = None ) -> Sequence[MemoryBlock]: + if created_by.account_id: + created_by_role = CreatorUserRole.ACCOUNT + created_by_id = created_by.account_id + else: + created_by_role = CreatorUserRole.END_USER + created_by_id = created_by.id if version is None: # If version not specified, get the latest version stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where( and_( ChatflowMemoryVariable.tenant_id == app.tenant_id, ChatflowMemoryVariable.app_id == app.id, - ChatflowMemoryVariable.conversation_id == None + ChatflowMemoryVariable.conversation_id == None, + ChatflowMemoryVariable.created_by_role == created_by_role, + ChatflowMemoryVariable.created_by == created_by_id, ) ).order_by(ChatflowMemoryVariable.version.desc()) else: @@ -54,16 +64,19 @@ class ChatflowMemoryService: ChatflowMemoryVariable.tenant_id == app.tenant_id, ChatflowMemoryVariable.app_id == app.id, ChatflowMemoryVariable.conversation_id == None, + ChatflowMemoryVariable.created_by_role == created_by_role, + ChatflowMemoryVariable.created_by == created_by_id, ChatflowMemoryVariable.version == version ) ) with Session(db.engine) as session: db_results = session.execute(stmt).all() - return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results]) + return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results]) @staticmethod def get_session_memories( app: App, + created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None ) -> Sequence[MemoryBlock]: @@ -87,12 +100,18 @@ class ChatflowMemoryService: ) with Session(db.engine) as session: db_results = session.execute(stmt).all() - return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results]) + return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results]) @staticmethod def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None: key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value) + if memory.created_by.account_id: + created_by_role = CreatorUserRole.ACCOUNT + created_by = memory.created_by.account_id + else: + created_by_role = CreatorUserRole.END_USER + created_by = memory.created_by.id with Session(db.engine) as session: existing = session.query(ChatflowMemoryVariable).filter_by( @@ -100,7 +119,9 @@ class ChatflowMemoryService: tenant_id=memory.tenant_id, app_id=memory.app_id, node_id=memory.node_id, - conversation_id=memory.conversation_id + conversation_id=memory.conversation_id, + created_by_role=created_by_role, + created_by=created_by, ).order_by(ChatflowMemoryVariable.version.desc()).first() new_version = 1 if not existing else existing.version + 1 session.add( @@ -118,6 +139,8 @@ class ChatflowMemoryService: term=memory.spec.term, scope=memory.spec.scope, version=new_version, + created_by_role=created_by_role, + created_by=created_by, ) ) session.commit() @@ -149,12 +172,13 @@ class ChatflowMemoryService: def get_memories_by_specs( memory_block_specs: Sequence[MemoryBlockSpec], tenant_id: str, app_id: str, + created_by: MemoryCreatedBy, conversation_id: Optional[str], node_id: Optional[str], is_draft: bool ) -> Sequence[MemoryBlock]: return [ChatflowMemoryService.get_memory_by_spec( - spec, tenant_id, app_id, conversation_id, node_id, is_draft + spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft ) for spec in memory_block_specs] @staticmethod @@ -162,6 +186,7 @@ class ChatflowMemoryService: spec: MemoryBlockSpec, tenant_id: str, app_id: str, + created_by: MemoryCreatedBy, conversation_id: Optional[str], node_id: Optional[str], is_draft: bool @@ -183,7 +208,8 @@ class ChatflowMemoryService: app_id=app_id, conversation_id=conversation_id, node_id=node_id, - spec=spec + spec=spec, + created_by=created_by, ) stmt = select(ChatflowMemoryVariable).where( and_( @@ -206,7 +232,8 @@ class ChatflowMemoryService: conversation_id=conversation_id, node_id=node_id, spec=spec, - edited_by_user=memory_value_data.edited_by_user + edited_by_user=memory_value_data.edited_by_user, + created_by=created_by, ) return MemoryBlock( tenant_id=tenant_id, @@ -214,7 +241,8 @@ class ChatflowMemoryService: app_id=app_id, conversation_id=conversation_id, node_id=node_id, - spec=spec + spec=spec, + created_by=created_by, ) @staticmethod @@ -222,6 +250,7 @@ class ChatflowMemoryService: workflow: Workflow, conversation_id: str, variable_pool: VariablePool, + created_by: MemoryCreatedBy, is_draft: bool ): visible_messages = ChatflowHistoryService.get_visible_chat_history( @@ -240,7 +269,8 @@ class ChatflowMemoryService: app_id=workflow.app_id, conversation_id=conversation_id, node_id=None, - is_draft=is_draft + is_draft=is_draft, + created_by=created_by, ) if ChatflowMemoryService._should_update_memory(memory, visible_messages): if memory.spec.schedule_mode == MemoryScheduleMode.SYNC: @@ -276,6 +306,7 @@ class ChatflowMemoryService: tenant_id: str, app_id: str, node_id: str, + created_by: MemoryCreatedBy, conversation_id: str, memory_block_spec: MemoryBlockSpec, variable_pool: VariablePool, @@ -293,7 +324,8 @@ class ChatflowMemoryService: app_id=app_id, conversation_id=conversation_id, node_id=node_id, - is_draft=is_draft + is_draft=is_draft, + created_by=created_by, ) if not ChatflowMemoryService._should_update_memory( memory_block=memory_block, @@ -356,6 +388,7 @@ class ChatflowMemoryService: @staticmethod def _convert_to_memory_blocks( app: App, + created_by: MemoryCreatedBy, raw_results: Sequence[ChatflowMemoryVariable] ) -> Sequence[MemoryBlock]: workflow = WorkflowService().get_published_workflow(app) @@ -377,7 +410,8 @@ class ChatflowMemoryService: app_id=chatflow_memory_variable.app_id, conversation_id=chatflow_memory_variable.conversation_id, node_id=chatflow_memory_variable.node_id, - edited_by_user=memory_value_data.edited_by_user + edited_by_user=memory_value_data.edited_by_user, + created_by=created_by, ) ) return results @@ -515,7 +549,8 @@ class ChatflowMemoryService: app_id=memory_block.app_id, conversation_id=memory_block.conversation_id, node_id=memory_block.node_id, - edited_by_user=False + edited_by_user=False, + created_by=memory_block.created_by, ) ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index cc33042950..4137f2df91 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.memory.entities import MemoryScope +from core.memory.entities import MemoryCreatedBy, MemoryScope from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion @@ -1008,7 +1008,6 @@ def _setup_variable_pool( system_variable.dialogue_count = 1 else: system_variable = SystemVariable.empty() - # init variable pool variable_pool = VariablePool( system_variables=system_variable, @@ -1017,7 +1016,12 @@ def _setup_variable_pool( # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. conversation_variables=cast(list[VariableUnion], conversation_variables), # - memory_blocks=_fetch_memory_blocks(workflow, conversation_id, is_draft=is_draft), + memory_blocks=_fetch_memory_blocks( + workflow, + MemoryCreatedBy(account_id=user_id), + conversation_id, + is_draft=is_draft + ), ) return variable_pool @@ -1056,7 +1060,12 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia raise Exception("unreachable") -def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: bool) -> Mapping[str, str]: +def _fetch_memory_blocks( + workflow: Workflow, + created_by: MemoryCreatedBy, + conversation_id: str, + is_draft: bool +) -> Mapping[str, str]: memory_blocks = {} memory_block_specs = workflow.memory_blocks memories = ChatflowMemoryService.get_memories_by_specs( @@ -1066,6 +1075,7 @@ def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: boo node_id=None, conversation_id=conversation_id, is_draft=is_draft, + created_by=created_by, ) for memory in memories: if memory.spec.scope == MemoryScope.APP: