From 27b084c4d4eba04da0e436c54f4d9e351fb13097 Mon Sep 17 00:00:00 2001 From: Yunlu Wen Date: Fri, 15 May 2026 16:39:48 +0800 Subject: [PATCH] fix: reduce db roundtrips of message update (#36213) --- .../advanced_chat/generate_task_pipeline.py | 8 ++------ .../test_generate_task_pipeline_core.py | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 82dbf5381d..3c46f91e51 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,7 +9,7 @@ from datetime import datetime from threading import Thread from typing import Any, Union -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -425,11 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id = run_id with self._database_session() as session: - message = self._get_message(session=session) - if not message: - raise ValueError(f"Message not found: {self._message_id}") - - message.workflow_run_id = run_id + session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id)) workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index d8f794b483..d2e1ceb69d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -234,9 +234,19 @@ class TestAdvancedChatGenerateTaskPipeline: ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + # Track database operations for verification + executed_statements = [] + @contextmanager def _fake_session(): - yield SimpleNamespace() + sess = SimpleNamespace() + + def _execute(stmt): + executed_statements.append(stmt) + return SimpleNamespace() + + sess.execute = _execute + yield sess monkeypatch.setattr(pipeline, "_database_session", _fake_session) monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace()) @@ -246,6 +256,14 @@ class TestAdvancedChatGenerateTaskPipeline: assert pipeline._workflow_run_id == "run-id" assert responses == ["started"] + # Verify database operation was executed + assert len(executed_statements) == 1 + # Verify the UPDATE statement targets the correct message and sets workflow_run_id + update_stmt = executed_statements[0] + stmt_str = str(update_stmt) + assert "UPDATE messages" in stmt_str + assert "WHERE messages.id" in stmt_str + def test_message_end_to_stream_response_strips_annotation_reply(self): pipeline = _make_pipeline() pipeline._task_state.metadata.annotation_reply = AnnotationReply(