fix: reduce db roundtrips of message update (#36213)

This commit is contained in:
Yunlu Wen 2026-05-15 16:39:48 +08:00 committed by GitHub
parent 3f7a68fc77
commit 27b084c4d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 7 deletions

View File

@ -9,7 +9,7 @@ from datetime import datetime
from threading import Thread
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy import select, update
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -425,11 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = run_id
session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id))
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,

View File

@ -234,9 +234,19 @@ class TestAdvancedChatGenerateTaskPipeline:
)
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
# Track database operations for verification
executed_statements = []
@contextmanager
def _fake_session():
yield SimpleNamespace()
sess = SimpleNamespace()
def _execute(stmt):
executed_statements.append(stmt)
return SimpleNamespace()
sess.execute = _execute
yield sess
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
@ -246,6 +256,14 @@ class TestAdvancedChatGenerateTaskPipeline:
assert pipeline._workflow_run_id == "run-id"
assert responses == ["started"]
# Verify database operation was executed
assert len(executed_statements) == 1
# Verify the UPDATE statement targets the correct message and sets workflow_run_id
update_stmt = executed_statements[0]
stmt_str = str(update_stmt)
assert "UPDATE messages" in stmt_str
assert "WHERE messages.id" in stmt_str
def test_message_end_to_stream_response_strips_annotation_reply(self):
pipeline = _make_pipeline()
pipeline._task_state.metadata.annotation_reply = AnnotationReply(