feat: add created_by to memory blocks

This commit is contained in:
Stream 2025-09-23 16:56:07 +08:00
parent d94e598a89
commit 6eab6a675c
No known key found for this signature in database
GPG Key ID: 033728094B100D70
8 changed files with 135 additions and 38 deletions

View File

@ -2,27 +2,40 @@ from flask_restx import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.memory.entities import MemoryBlock
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def get(self, app_model):
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id = args.get("conversation_id")
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
result = ChatflowMemoryService.get_persistent_memories(app_model, version)
result = ChatflowMemoryService.get_persistent_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
version
)
if conversation_id:
result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)]
result = [
*result,
*ChatflowMemoryService.get_session_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
]
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
@ -30,7 +43,7 @@ class MemoryListApi(Resource):
class MemoryEditApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def put(self, app_model):
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
@ -56,7 +69,8 @@ class MemoryEditApi(Resource):
conversation_id=conversation_id,
node_id=node_id,
app_id=app_model.id,
edited_by_user=True
edited_by_user=True,
created_by=MemoryCreatedBy(end_user_id=end_user.id),
),
variable_pool=VariablePool(),
is_draft=False

View File

@ -2,33 +2,46 @@ from flask_restx import reqparse
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.memory.entities import MemoryBlock
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(WebApiResource):
def get(self, app_model):
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id = args.get("conversation_id")
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
result = ChatflowMemoryService.get_persistent_memories(app_model, version)
result = ChatflowMemoryService.get_persistent_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
version
)
if conversation_id:
result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)]
result = [
*result,
*ChatflowMemoryService.get_session_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id),
conversation_id,
version
)
]
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(WebApiResource):
def put(self, app_model):
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
@ -56,7 +69,8 @@ class MemoryEditApi(WebApiResource):
conversation_id=conversation_id,
node_id=node_id,
app_id=app_model.id,
edited_by_user=True
edited_by_user=True,
created_by=MemoryCreatedBy(end_user_id=end_user.id)
),
variable_pool=VariablePool(),
is_draft=False

View File

@ -21,7 +21,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.memory.entities import MemoryScope
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
@ -443,7 +443,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_id=self._workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft
is_draft=is_draft,
created_by=self._get_created_by(),
)
# Build memory_id -> value mapping
@ -482,5 +483,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow=self._workflow,
conversation_id=self.conversation.id,
variable_pool=variable_pool,
is_draft=is_draft
is_draft=is_draft,
created_by=self._get_created_by()
)
def _get_created_by(self) -> MemoryCreatedBy:
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
else:
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)

View File

@ -49,6 +49,11 @@ class MemoryBlockSpec(BaseModel):
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
class MemoryCreatedBy(BaseModel):
end_user_id: str | None = None
account_id: str | None = None
class MemoryBlock(BaseModel):
"""Runtime memory block instance
@ -69,6 +74,7 @@ class MemoryBlock(BaseModel):
conversation_id: Optional[str] = None
node_id: Optional[str] = None
edited_by_user: bool = False
created_by: MemoryCreatedBy
class MemoryValueData(BaseModel):

View File

@ -14,7 +14,7 @@ from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryScope
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@ -74,7 +74,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from models import Workflow, db
from models import UserFrom, Workflow
from models.engine import db
from services.chatflow_memory_service import ChatflowMemoryService
from . import llm_utils
@ -1242,13 +1243,20 @@ class LLMNode(Node):
ChatflowMemoryService.update_node_memory_if_needed(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.node_id,
node_id=self.id,
conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft
is_draft=is_draft,
created_by=self._get_user_from_context()
)
def _get_user_from_context(self) -> MemoryCreatedBy:
if self.user_from == UserFrom.ACCOUNT:
return MemoryCreatedBy(account_id=self.user_id)
else:
return MemoryCreatedBy(end_user_id=self.user_id)
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@ -26,6 +26,8 @@ class ChatflowMemoryVariable(Base):
scope: Mapped[str] = mapped_column(sa.String(10), nullable=False) # 'app' or 'node'
term: Mapped[str] = mapped_column(sa.String(20), nullable=False) # 'session' or 'persistent'
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
created_by_role: Mapped[str] = mapped_column(sa.String(20)) # 'end_user' or 'account`
created_by: Mapped[str] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(

View File

@ -11,6 +11,7 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.memory.entities import (
MemoryBlock,
MemoryBlockSpec,
MemoryCreatedBy,
MemoryScheduleMode,
MemoryScope,
MemoryTerm,
@ -23,7 +24,7 @@ from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import App
from models import App, CreatorUserRole
from models.chatflow_memory import ChatflowMemoryVariable
from models.workflow import Workflow, WorkflowDraftVariable
from services.chatflow_history_service import ChatflowHistoryService
@ -37,15 +38,24 @@ class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
app: App,
created_by: MemoryCreatedBy,
version: int | None = None
) -> Sequence[MemoryBlock]:
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
).order_by(ChatflowMemoryVariable.version.desc())
else:
@ -54,16 +64,19 @@ class ChatflowMemoryService:
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
ChatflowMemoryVariable.version == version
)
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results])
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def get_session_memories(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
) -> Sequence[MemoryBlock]:
@ -87,12 +100,18 @@ class ChatflowMemoryService:
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results])
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id
variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value)
if memory.created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by = memory.created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by = memory.created_by.id
with Session(db.engine) as session:
existing = session.query(ChatflowMemoryVariable).filter_by(
@ -100,7 +119,9 @@ class ChatflowMemoryService:
tenant_id=memory.tenant_id,
app_id=memory.app_id,
node_id=memory.node_id,
conversation_id=memory.conversation_id
conversation_id=memory.conversation_id,
created_by_role=created_by_role,
created_by=created_by,
).order_by(ChatflowMemoryVariable.version.desc()).first()
new_version = 1 if not existing else existing.version + 1
session.add(
@ -118,6 +139,8 @@ class ChatflowMemoryService:
term=memory.spec.term,
scope=memory.spec.scope,
version=new_version,
created_by_role=created_by_role,
created_by=created_by,
)
)
session.commit()
@ -149,12 +172,13 @@ class ChatflowMemoryService:
def get_memories_by_specs(
memory_block_specs: Sequence[MemoryBlockSpec],
tenant_id: str, app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
) -> Sequence[MemoryBlock]:
return [ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, conversation_id, node_id, is_draft
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
) for spec in memory_block_specs]
@staticmethod
@ -162,6 +186,7 @@ class ChatflowMemoryService:
spec: MemoryBlockSpec,
tenant_id: str,
app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
@ -183,7 +208,8 @@ class ChatflowMemoryService:
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec
spec=spec,
created_by=created_by,
)
stmt = select(ChatflowMemoryVariable).where(
and_(
@ -206,7 +232,8 @@ class ChatflowMemoryService:
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
edited_by_user=memory_value_data.edited_by_user
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
)
return MemoryBlock(
tenant_id=tenant_id,
@ -214,7 +241,8 @@ class ChatflowMemoryService:
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec
spec=spec,
created_by=created_by,
)
@staticmethod
@ -222,6 +250,7 @@ class ChatflowMemoryService:
workflow: Workflow,
conversation_id: str,
variable_pool: VariablePool,
created_by: MemoryCreatedBy,
is_draft: bool
):
visible_messages = ChatflowHistoryService.get_visible_chat_history(
@ -240,7 +269,8 @@ class ChatflowMemoryService:
app_id=workflow.app_id,
conversation_id=conversation_id,
node_id=None,
is_draft=is_draft
is_draft=is_draft,
created_by=created_by,
)
if ChatflowMemoryService._should_update_memory(memory, visible_messages):
if memory.spec.schedule_mode == MemoryScheduleMode.SYNC:
@ -276,6 +306,7 @@ class ChatflowMemoryService:
tenant_id: str,
app_id: str,
node_id: str,
created_by: MemoryCreatedBy,
conversation_id: str,
memory_block_spec: MemoryBlockSpec,
variable_pool: VariablePool,
@ -293,7 +324,8 @@ class ChatflowMemoryService:
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
is_draft=is_draft
is_draft=is_draft,
created_by=created_by,
)
if not ChatflowMemoryService._should_update_memory(
memory_block=memory_block,
@ -356,6 +388,7 @@ class ChatflowMemoryService:
@staticmethod
def _convert_to_memory_blocks(
app: App,
created_by: MemoryCreatedBy,
raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
workflow = WorkflowService().get_published_workflow(app)
@ -377,7 +410,8 @@ class ChatflowMemoryService:
app_id=chatflow_memory_variable.app_id,
conversation_id=chatflow_memory_variable.conversation_id,
node_id=chatflow_memory_variable.node_id,
edited_by_user=memory_value_data.edited_by_user
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
)
)
return results
@ -515,7 +549,8 @@ class ChatflowMemoryService:
app_id=memory_block.app_id,
conversation_id=memory_block.conversation_id,
node_id=memory_block.node_id,
edited_by_user=False
edited_by_user=False,
created_by=memory_block.created_by,
)
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)

View File

@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.memory.entities import MemoryScope
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
@ -1008,7 +1008,6 @@ def _setup_variable_pool(
system_variable.dialogue_count = 1
else:
system_variable = SystemVariable.empty()
# init variable pool
variable_pool = VariablePool(
system_variables=system_variable,
@ -1017,7 +1016,12 @@ def _setup_variable_pool(
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
memory_blocks=_fetch_memory_blocks(workflow, conversation_id, is_draft=is_draft),
memory_blocks=_fetch_memory_blocks(
workflow,
MemoryCreatedBy(account_id=user_id),
conversation_id,
is_draft=is_draft
),
)
return variable_pool
@ -1056,7 +1060,12 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
raise Exception("unreachable")
def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: bool) -> Mapping[str, str]:
def _fetch_memory_blocks(
workflow: Workflow,
created_by: MemoryCreatedBy,
conversation_id: str,
is_draft: bool
) -> Mapping[str, str]:
memory_blocks = {}
memory_block_specs = workflow.memory_blocks
memories = ChatflowMemoryService.get_memories_by_specs(
@ -1066,6 +1075,7 @@ def _fetch_memory_blocks(workflow: Workflow, conversation_id: str, is_draft: boo
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft,
created_by=created_by,
)
for memory in memories:
if memory.spec.scope == MemoryScope.APP: