mirror of https://github.com/langgenius/dify.git
feat: add endpoints to delete memory
This commit is contained in:
parent
75c221038d
commit
3d7d4182a6
|
|
@ -78,5 +78,29 @@ class MemoryEditApi(Resource):
|
|||
return '', 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get('id')
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(
|
||||
app_model,
|
||||
memory_id,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
api.add_resource(MemoryDeleteApi, '/memories')
|
||||
|
|
|
|||
|
|
@ -78,5 +78,28 @@ class MemoryEditApi(WebApiResource):
|
|||
return '', 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(WebApiResource):
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get('id')
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(
|
||||
app_model,
|
||||
memory_id,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
api.add_resource(MemoryDeleteApi, '/memories')
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import time
|
|||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
|
|
@ -554,6 +554,57 @@ class ChatflowMemoryService:
|
|||
)
|
||||
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
|
||||
|
||||
@staticmethod
|
||||
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
|
||||
workflow = WorkflowService().get_published_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
|
||||
if not memory_spec or not memory_spec.end_user_editable:
|
||||
raise ValueError("Memory not found or not deletable")
|
||||
|
||||
if created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by_id = created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by_id = created_by.id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = delete(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.memory_id == memory_id,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
|
||||
if created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by_id = created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by_id = created_by.id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = delete(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
|
||||
result = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue