refactor: make memories API return MemoryBlock

This commit is contained in:
Stream 2025-09-22 17:14:07 +08:00
parent d654d9d8b1
commit a8c2a300f6
No known key found for this signature in database
GPG Key ID: 033728094B100D70
3 changed files with 35 additions and 8 deletions

View File

@ -24,8 +24,8 @@ class MemoryListApi(Resource):
if conversation_id:
result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)]
if memory_id:
result = [it for it in result if it.memory_id == memory_id]
return [it for it in result if it.end_user_visible]
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(Resource):

View File

@ -23,8 +23,8 @@ class MemoryListApi(WebApiResource):
if conversation_id:
result = [*result, *ChatflowMemoryService.get_session_memories(app_model, conversation_id, version)]
if memory_id:
result = [it for it in result if it.memory_id == memory_id]
return [it for it in result if it.end_user_visible]
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(WebApiResource):

View File

@ -39,7 +39,7 @@ class ChatflowMemoryService:
def get_persistent_memories(
app: App,
version: int | None = None
) -> Sequence[MemoryBlockWithVisibility]:
) -> Sequence[MemoryBlock]:
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
@ -60,14 +60,14 @@ class ChatflowMemoryService:
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results])
return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results])
@staticmethod
def get_session_memories(
app: App,
conversation_id: str,
version: int | None = None
) -> Sequence[MemoryBlockWithVisibility]:
) -> Sequence[MemoryBlock]:
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
@ -88,7 +88,7 @@ class ChatflowMemoryService:
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._with_visibility(app, [result[0] for result in db_results])
return ChatflowMemoryService._convert_to_memory_blocks(app, [result[0] for result in db_results])
@staticmethod
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
@ -349,6 +349,33 @@ class ChatflowMemoryService:
conversation_id=conversation_id
)
@staticmethod
def _convert_to_memory_blocks(
app: App,
raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
return []
results = []
for chatflow_memory_variable in raw_results:
spec = next(
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id),
None
)
if spec and chatflow_memory_variable.app_id:
results.append(
MemoryBlock(
spec=spec,
tenant_id=chatflow_memory_variable.tenant_id,
value=MemoryValueData.model_validate_json(chatflow_memory_variable.value).value,
app_id=chatflow_memory_variable.app_id,
conversation_id=chatflow_memory_variable.conversation_id,
node_id=chatflow_memory_variable.node_id
)
)
return results
@staticmethod
def _with_visibility(
app: App,