mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
fix: reduce db roundtrips of message update (#36213)
This commit is contained in:
parent
3f7a68fc77
commit
27b084c4d4
@ -9,7 +9,7 @@ from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@ -425,11 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
message = self._get_message(session=session)
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
|
||||
message.workflow_run_id = run_id
|
||||
session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id))
|
||||
|
||||
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
||||
@ -234,9 +234,19 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
# Track database operations for verification
|
||||
executed_statements = []
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
sess = SimpleNamespace()
|
||||
|
||||
def _execute(stmt):
|
||||
executed_statements.append(stmt)
|
||||
return SimpleNamespace()
|
||||
|
||||
sess.execute = _execute
|
||||
yield sess
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
|
||||
@ -246,6 +256,14 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
assert pipeline._workflow_run_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
# Verify database operation was executed
|
||||
assert len(executed_statements) == 1
|
||||
# Verify the UPDATE statement targets the correct message and sets workflow_run_id
|
||||
update_stmt = executed_statements[0]
|
||||
stmt_str = str(update_stmt)
|
||||
assert "UPDATE messages" in stmt_str
|
||||
assert "WHERE messages.id" in stmt_str
|
||||
|
||||
def test_message_end_to_stream_response_strips_annotation_reply(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user