feat: add endpoints to delete memory

This commit is contained in:
Stream 2025-09-23 19:07:37 +08:00
parent 75c221038d
commit 3d7d4182a6
No known key found for this signature in database
GPG Key ID: 033728094B100D70
3 changed files with 99 additions and 1 deletions

View File

@ -78,5 +78,29 @@ class MemoryEditApi(Resource):
return '', 204
class MemoryDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get('id')
if memory_id:
ChatflowMemoryService.delete_memory(
app_model,
memory_id,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 204
else:
ChatflowMemoryService.delete_all_user_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 200
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -78,5 +78,28 @@ class MemoryEditApi(WebApiResource):
return '', 204
class MemoryDeleteApi(WebApiResource):
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('id', type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get('id')
if memory_id:
ChatflowMemoryService.delete_memory(
app_model,
memory_id,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 204
else:
ChatflowMemoryService.delete_all_user_memories(
app_model,
MemoryCreatedBy(end_user_id=end_user.id)
)
return '', 200
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -4,7 +4,7 @@ import time
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import and_, select
from sqlalchemy import and_, delete, select
from sqlalchemy.orm import Session
from core.llm_generator.llm_generator import LLMGenerator
@ -554,6 +554,57 @@ class ChatflowMemoryService:
)
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
@staticmethod
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
raise ValueError("Workflow not found")
memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
if not memory_spec or not memory_spec.end_user_editable:
raise ValueError("Memory not found or not deletable")
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.memory_id == memory_id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
)
)
session.execute(stmt)
session.commit()
@staticmethod
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
)
)
session.execute(stmt)
session.commit()
@staticmethod
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
result = []