From 0310f631ee966c728c06cc64a52db4b39c5d08d6 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 11 Feb 2026 10:57:27 +0800 Subject: [PATCH] fix: fix get_message_event_type return wrong message type (#32019) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../easy_ui_based_generate_task_pipeline.py | 84 ++++++++++++++++++- .../task_pipeline/message_cycle_manager.py | 8 +- ...test_message_cycle_manager_optimization.py | 37 +++++++- 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6c997753fa..833f32fc7d 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -45,6 +45,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk +from core.file import helpers as file_helpers +from core.file.enums import FileTransferMethod from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( @@ -56,10 +58,11 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.tools.signature import sign_tool_file from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.model import AppMode, Conversation, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile logger = logging.getLogger(__name__) @@ -463,6 +466,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): metadata=metadata_dict, ) + def _record_files(self): + with Session(db.engine, expire_on_commit=False) as session: + message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() + if not message_files: + return None + + files_list = [] + upload_file_ids = [ + mf.upload_file_id + for mf in message_files + if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id + ] + upload_files_map = {} + if upload_file_ids: + upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() + upload_files_map = {uf.id: uf for uf in upload_files} + + for message_file in message_files: + upload_file = None + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: + upload_file = upload_files_map.get(message_file.upload_file_id) + + url = None + filename = "file" + mime_type = "application/octet-stream" + size = 0 + extension = "" + + if message_file.transfer_method == FileTransferMethod.REMOTE_URL: + url = message_file.url + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params + elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + # Fallback: generate URL even if upload_file not found + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: + # For tool files, use URL directly if it's HTTP, otherwise sign it + if message_file.url.startswith("http"): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + else: + # Extract tool file id and extension from URL + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] # Remove query params first + # Use rsplit to correctly handle filenames with multiple dots + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + else: + tool_file_id = file_part + extension = ".bin" + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + + transfer_method_value = message_file.transfer_method + remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" + file_dict = { + "related_id": message_file.id, + "extension": extension, + "filename": filename, + "size": size, + "mime_type": mime_type, + "transfer_method": transfer_method_value, + "type": message_file.type, + "url": url or "", + "upload_file_id": message_file.upload_file_id or message_file.id, + "remote_url": remote_url, + } + files_list.append(file_dict) + return files_list or None + def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: """ Agent message to stream response. diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index d682083f34..cc4f97ad94 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -64,7 +64,13 @@ class MessageCycleManager: # Use SQLAlchemy 2.x style session.scalar(select(...)) with session_factory.create_session() as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) + message_file = session.scalar( + select(MessageFile) + .where( + MessageFile.message_id == message_id, + ) + .where(MessageFile.belongs_to == "assistant") + ) if message_file: self._message_has_file.add(message_id) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5a43a247e3..c0c636715d 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -25,15 +25,19 @@ class TestMessageCycleManagerOptimization: task_state = Mock() return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) - def test_get_message_event_type_with_message_file(self, message_cycle_manager): - """Test get_message_event_type returns MESSAGE_FILE when message has files.""" + def test_get_message_event_type_with_assistant_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has assistant-generated files. + + This ensures that AI-generated images (belongs_to='assistant') trigger the MESSAGE_FILE event, + allowing the frontend to properly display generated image files with url field. + """ with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.scalar(select(...)) + mock_message_file.belongs_to = "assistant" mock_session.scalar.return_value = mock_message_file # Execute @@ -44,6 +48,31 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE_FILE mock_session.scalar.assert_called_once() + def test_get_message_event_type_with_user_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message only has user-uploaded files. + + This is a regression test for the issue where user-uploaded images (belongs_to='user') + caused the LLM text response to be incorrectly tagged with MESSAGE_FILE event, + resulting in broken images in the chat UI. The query filters for belongs_to='assistant', + so when only user files exist, the database query returns None, resulting in MESSAGE event type. + """ + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + # Setup mock session and message file + mock_session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # When querying for assistant files with only user files present, return None + # (simulates database query with belongs_to='assistant' filter returning no results) + mock_session.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.scalar.assert_called_once() + def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -69,7 +98,7 @@ class TestMessageCycleManagerOptimization: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.scalar(select(...)) + mock_message_file.belongs_to = "assistant" mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response