diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..8b6b8f227b 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner): queue_manager=queue_manager, stream=application_generate_entity.stream, agent=True, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e2e6c11480..617515945b 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,8 @@ +import base64 import logging import time from collections.abc import Generator, Mapping, Sequence +from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ModelConfigWithCredentialsEntity, ) -from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, +) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.file.enums import FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError @@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform -from models.model import App, AppMode, Message, MessageAnnotation +from core.tools.tool_file_manager import ToolFileManager +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: from core.file.models import File @@ -203,6 +215,9 @@ class AppRunner: queue_manager: AppQueueManager, stream: bool, agent: bool = False, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result @@ -210,21 +225,41 @@ class AppRunner: :param queue_manager: application queue manager :param stream: stream :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ if not stream and isinstance(invoke_result, LLMResult): - self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_direct( + invoke_result=invoke_result, + queue_manager=queue_manager, + ) elif stream and isinstance(invoke_result, Generator): - self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_stream( + invoke_result=invoke_result, + queue_manager=queue_manager, + agent=agent, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + ) else: raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") - def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): + def _handle_invoke_result_direct( + self, + invoke_result: LLMResult, + queue_manager: AppQueueManager, + ): """ Handle invoke result direct :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ queue_manager.publish( @@ -235,13 +270,22 @@ class AppRunner: ) def _handle_invoke_result_stream( - self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool + self, + invoke_result: Generator[LLMResultChunk, None, None], + queue_manager: AppQueueManager, + agent: bool, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ model: str = "" @@ -259,12 +303,26 @@ class AppRunner: text += message.content elif isinstance(message.content, list): for content in message.content: - if not isinstance(content, str): - # TODO(QuantumGhost): Add multimodal output support for easy ui. - _logger.warning("received multimodal output, type=%s", type(content)) + if isinstance(content, str): + text += content + elif isinstance(content, TextPromptMessageContent): text += content.data + elif isinstance(content, ImagePromptMessageContent): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") else: - text += content # failback to str + text += content.data if hasattr(content, "data") else str(content) if not model: model = result.model @@ -289,6 +347,101 @@ class AppRunner: PublishFrom.APPLICATION_MANAGER, ) + def _handle_multimodal_image_content( + self, + content: ImagePromptMessageContent, + message_id: str, + user_id: str, + tenant_id: str, + queue_manager: AppQueueManager, + ): + """ + Handle multimodal image content from LLM response. + Save the image and create a MessageFile record. + + :param content: ImagePromptMessageContent instance + :param message_id: message id + :param user_id: user id + :param tenant_id: tenant id + :param queue_manager: queue manager + :return: + """ + _logger.info("Handling multimodal image content for message %s", message_id) + + image_url = content.url + base64_data = content.base64_data + + _logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data) + + if not image_url and not base64_data: + _logger.warning("Image content has neither URL nor base64 data") + return + + tool_file_manager = ToolFileManager() + + # Save the image file + try: + if image_url: + # Download image from URL + _logger.info("Downloading image from URL: %s", image_url) + tool_file = tool_file_manager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=image_url, + conversation_id=None, + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + elif base64_data: + if base64_data.startswith("data:"): + base64_data = base64_data.split(",", 1)[1] + + image_binary = base64.b64decode(base64_data) + mimetype = content.mime_type or "image/png" + extension = guess_extension(mimetype) or ".png" + + tool_file = tool_file_manager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=image_binary, + mimetype=mimetype, + filename=f"generated_image{extension}", + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + else: + return + except Exception: + _logger.exception("Failed to save image file") + return + + # Create MessageFile record + message_file = MessageFile( + message_id=message_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + belongs_to="assistant", + url=f"/files/tools/{tool_file.id}", + upload_file_id=tool_file.id, + created_by_role=( + CreatorUserRole.ACCOUNT + if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} + else CreatorUserRole.END_USER + ), + created_by=user_id, + ) + + db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + + # Publish QueueMessageFileEvent + queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file.id), + PublishFrom.APPLICATION_MANAGER, + ) + + _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id) + def moderation_for_inputs( self, *, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f8338b226b..7d1a4c619f 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ddfb5725b4..a872c2e1f7 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..6c997753fa 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -39,6 +39,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + StreamEvent, StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + _precomputed_event_type: StreamEvent | None = None def __init__( self, @@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) + # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks + if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None: + self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( + message_id=self._message_id + ) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, - event_type=event_type, + event_type=self._precomputed_event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0e7f300cee..2d4ee08daf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import exists, select +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -30,6 +30,7 @@ from core.app.entities.task_entities import ( StreamEvent, WorkflowTaskState, ) +from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db @@ -57,13 +58,15 @@ class MessageCycleManager: self._message_has_file: set[str] = set() def get_message_event_type(self, message_id: str) -> StreamEvent: + # Fast path: cached determination from prior QueueMessageFileEvent if message_id in self._message_has_file: return StreamEvent.MESSAGE_FILE - with Session(db.engine, expire_on_commit=False) as session: - has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + # Use SQLAlchemy 2.x style session.scalar(select(...)) + with session_factory.create_session() as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) - if has_file: + if message_file: self._message_has_file.add(message_id) return StreamEvent.MESSAGE_FILE @@ -199,6 +202,8 @@ class MessageCycleManager: message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) if message_file and message_file.url is not None: + self._message_has_file.add(message_file.message_id) + # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py new file mode 100644 index 0000000000..421a5246eb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -0,0 +1,454 @@ +"""Test multimodal image output handling in BaseAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from core.file.enums import FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from models.enums import CreatorUserRole + + +class TestBaseAppRunnerMultimodal: + """Test that BaseAppRunner correctly handles multimodal image content.""" + + @pytest.fixture + def mock_user_id(self): + """Mock user ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_tenant_id(self): + """Mock tenant ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_message_id(self): + """Mock message ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = MagicMock() + manager.invoke_from = InvokeFrom.SERVICE_API + return manager + + @pytest.fixture + def mock_tool_file(self): + """Create a mock tool file.""" + tool_file = MagicMock() + tool_file.id = str(uuid4()) + return tool_file + + @pytest.fixture + def mock_message_file(self): + """Create a mock message file.""" + message_file = MagicMock() + message_file.id = str(uuid4()) + return message_file + + def test_handle_multimodal_image_content_with_url( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from URL.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from URL + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + # Verify message file was created with correct parameters + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify database operations + mock_session.add.assert_called_once_with(mock_message_file) + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once_with(mock_message_file) + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id + # publish_from might be passed as positional or keyword argument + assert ( + publish_call[0][1] == PublishFrom.APPLICATION_MANAGER + or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER + ) + + def test_handle_multimodal_image_content_with_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data.""" + # Arrange + import base64 + + # Create a small test image (1x1 PNG) + test_image_data = base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + ).decode() + content = ImagePromptMessageContent( + base64_data=test_image_data, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from base64 + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + # Verify message file was created + mock_msg_file_class.assert_called_once() + + # Verify database operations + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + + def test_handle_multimodal_image_content_with_base64_data_uri( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data with URI prefix.""" + # Arrange + # Data URI format: data:image/png;base64, + test_image_data = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + content = ImagePromptMessageContent( + base64_data=f"data:image/png;base64,{test_image_data}", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify that base64 data was extracted correctly (without prefix) + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + # The base64 data should be decoded, so we check the binary was passed + assert "file_binary" in call_kwargs + + def test_handle_multimodal_image_content_without_url_or_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content without URL or base64 data.""" + # Arrange + content = ImagePromptMessageContent( + url="", + base64_data="", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create any files or publish events + mock_mgr_class.assert_not_called() + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_with_error( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content when an error occurs.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock to raise exception + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.side_effect = Exception("Network error") + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + # Should not raise exception, just log it + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create message file or publish event on error + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_debugger_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that debugger mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is ACCOUNT for debugger mode + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + + def test_handle_multimodal_image_content_service_api_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that service API mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is END_USER for service API + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5ef7f0d7f4..5a43a247e3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,7 +1,6 @@ """Unit tests for the message cycle manager optimization.""" -from types import SimpleNamespace -from unittest.mock import ANY, Mock, patch +from unittest.mock import Mock, patch import pytest from flask import current_app @@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization: def test_get_message_event_type_with_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE_FILE when message has files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute with current_app.app_context(): @@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and no message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # Execute with current_app.app_context(): @@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response with current_app.app_context(): @@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): """Test that message_to_stream_response skips database query when event_type is provided.""" - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Execute with event_type provided result = message_cycle_manager.message_to_stream_response( answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE @@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE - # Should not query database when event_type is provided - mock_session_class.assert_not_called() + # Should not open a session when event_type is provided + mock_session_factory.create_session.assert_not_called() def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): """Test message_to_stream_response with from_variable_selector parameter.""" @@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization: def test_optimization_usage_example(self, message_cycle_manager): """Test the optimization pattern that should be used by callers.""" # Step 1: Get event type once (this queries database) - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None # No files + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # No files with current_app.app_context(): event_type = message_cycle_manager.get_message_event_type("test-message-id") - # Should query database once - mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + # Should open session once + mock_session_factory.create_session.assert_called_once() assert event_type == StreamEvent.MESSAGE # Step 2: Use event_type for multiple calls (no additional queries) - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: - mock_session_class.return_value.__enter__.return_value = Mock() + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + mock_session_factory.create_session.return_value.__enter__.return_value = Mock() chunk1_response = message_cycle_manager.message_to_stream_response( answer="Chunk 1", message_id="test-message-id", event_type=event_type @@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization: answer="Chunk 2", message_id="test-message-id", event_type=event_type ) - # Should not query database again - mock_session_class.assert_not_called() + # Should not open session again when event_type provided + mock_session_factory.create_session.assert_not_called() assert chunk1_response.event == StreamEvent.MESSAGE assert chunk2_response.event == StreamEvent.MESSAGE diff --git a/web/app/components/base/chat/chat/hooks.multimodal.spec.ts b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts new file mode 100644 index 0000000000..2975d62887 --- /dev/null +++ b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts @@ -0,0 +1,178 @@ +/** + * Tests for multimodal image file handling in chat hooks. + * Tests the file object conversion logic without full hook integration. + */ + +describe('Multimodal File Handling', () => { + describe('File type to MIME type mapping', () => { + it('should map image to image/png', () => { + const fileType: string = 'image' + const expectedMime = 'image/png' + const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map video to video/mp4', () => { + const fileType: string = 'video' + const expectedMime = 'video/mp4' + const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map audio to audio/mpeg', () => { + const fileType: string = 'audio' + const expectedMime = 'audio/mpeg' + const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map unknown to application/octet-stream', () => { + const fileType: string = 'unknown' + const expectedMime = 'application/octet-stream' + const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + }) + + describe('TransferMethod selection', () => { + it('should select remote_url for images', () => { + const fileType: string = 'image' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('remote_url') + }) + + it('should select local_file for non-images', () => { + const fileType: string = 'video' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('local_file') + }) + }) + + describe('File extension mapping', () => { + it('should use .png extension for images', () => { + const fileType: string = 'image' + const expectedExtension = '.png' + const extension = fileType === 'image' ? 'png' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp4 extension for videos', () => { + const fileType: string = 'video' + const expectedExtension = '.mp4' + const extension = fileType === 'video' ? 'mp4' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp3 extension for audio', () => { + const fileType: string = 'audio' + const expectedExtension = '.mp3' + const extension = fileType === 'audio' ? 'mp3' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + }) + + describe('File name generation', () => { + it('should generate correct file name for images', () => { + const fileType: string = 'image' + const expectedName = 'generated_image.png' + const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for videos', () => { + const fileType: string = 'video' + const expectedName = 'generated_video.mp4' + const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for audio', () => { + const fileType: string = 'audio' + const expectedName = 'generated_audio.mp3' + const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + }) + + describe('SupportFileType mapping', () => { + it('should map image type to image supportFileType', () => { + const fileType: string = 'image' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('image') + }) + + it('should map video type to video supportFileType', () => { + const fileType: string = 'video' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('video') + }) + + it('should map audio type to audio supportFileType', () => { + const fileType: string = 'audio' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('audio') + }) + + it('should map unknown type to document supportFileType', () => { + const fileType: string = 'unknown' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('document') + }) + }) + + describe('File conversion logic', () => { + it('should detect existing transferMethod', () => { + const fileWithTransferMethod = { + id: 'file-123', + transferMethod: 'remote_url' as const, + type: 'image/png', + name: 'test.png', + size: 1024, + supportFileType: 'image', + progress: 100, + } + const hasTransferMethod = 'transferMethod' in fileWithTransferMethod + expect(hasTransferMethod).toBe(true) + }) + + it('should detect missing transferMethod', () => { + const fileWithoutTransferMethod = { + id: 'file-456', + type: 'image', + url: 'http://example.com/image.png', + belongs_to: 'assistant', + } + const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod + expect(hasTransferMethod).toBe(false) + }) + + it('should create file with size 0 for generated files', () => { + const expectedSize = 0 + expect(expectedSize).toBe(0) + }) + }) + + describe('Agent vs Non-Agent mode logic', () => { + it('should check for agent_thoughts to determine mode', () => { + const agentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [{}], + } + const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(true) + }) + + it('should detect non-agent mode when agent_thoughts is empty', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [], + } + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(false) + }) + + it('should detect non-agent mode when agent_thoughts is undefined', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = {} + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBeFalsy() + }) + }) +}) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 9b8a9b11dc..182aeebdbb 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -419,9 +419,40 @@ export const useChat = ( } }, onFile(file) { + // Convert simple file type to MIME type for non-agent mode + // Backend sends: { id, type: "image", belongs_to, url } + // Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size } + + // Determine file type for MIME conversion + const fileType = (file as { type?: string }).type || 'image' + + // If file already has transferMethod, use it as base and ensure all required fields exist + // Otherwise, create a new complete file object + const baseFile = ('transferMethod' in file) ? (file as Partial) : null + + const convertedFile: FileEntity = { + id: baseFile?.id || (file as { id: string }).id, + type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'), + transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'), + uploadedId: baseFile?.uploadedId || (file as { id: string }).id, + supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'), + progress: baseFile?.progress ?? 100, + name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`, + url: baseFile?.url || (file as { url?: string }).url, + size: baseFile?.size ?? 0, // Generated files don't have a known size + } + + // For agent mode, add files to the last thought const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1] - if (lastThought) - responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file] + if (lastThought) { + const thought = lastThought as { message_files?: FileEntity[] } + responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile] + } + // For non-agent mode, add files directly to responseItem.message_files + else { + const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? [] + responseItem.message_files = [...currentFiles, convertedFile] + } updateCurrentQAOnTree({ placeholderQuestionId, diff --git a/web/app/components/datasets/hit-testing/index.spec.tsx b/web/app/components/datasets/hit-testing/index.spec.tsx index 6bab3afb6a..07a78cd55f 100644 --- a/web/app/components/datasets/hit-testing/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/index.spec.tsx @@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => { renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Find submit button by class @@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Component should still be functional - check for the main container @@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - const { container } = renderWithProviders() + const { container: _container } = renderWithProviders() + + // Wait for textbox to be rendered with timeout for CI environment + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Submit @@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - // Verify the component is still rendered after submission - expect(container.firstChild).toBeInTheDocument() + // Wait for the mutation to complete + await waitFor( + () => { + expect(mockHitTestingMutateAsync).toHaveBeenCalled() + }, + { timeout: 3000 }, + ) }) it('should render ResultItem components for non-external results', async () => { @@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - const { container } = renderWithProviders() + const { container: _container } = renderWithProviders() + + // Wait for component to be fully rendered with longer timeout + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) // Submit a query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) const buttons = screen.getAllByRole('button') @@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - // Verify component is rendered after submission - expect(container.firstChild).toBeInTheDocument() + // Wait for mutation to complete with longer timeout + await waitFor( + () => { + expect(mockHitTestingMutateAsync).toHaveBeenCalled() + }, + { timeout: 3000 }, + ) }) it('should render external results when dataset is external', async () => { @@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => { // Component should render expect(container.firstChild).toBeInTheDocument() + + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type in textarea to verify component is functional - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) const buttons = screen.getAllByRole('button') @@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - await waitFor(() => { - expect(screen.getByRole('textbox')).toBeInTheDocument() - }) + // Verify component is still functional after submission + await waitFor( + () => { + expect(screen.getByRole('textbox')).toBeInTheDocument() + }, + { timeout: 3000 }, + ) }) }) @@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Enter query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test query' } }) // Submit @@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Enter query and submit - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test query' } }) const buttons = screen.getAllByRole('button') @@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => { // Wait for state updates await waitFor(() => { expect(container.firstChild).toBeInTheDocument() - }, { timeout: 2000 }) + }, { timeout: 3000 }) // Verify mutation was called expect(mockHitTestingMutateAsync).toHaveBeenCalled() @@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Submit a query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test' } }) const buttons = screen.getAllByRole('button') @@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => { // Verify the component renders await waitFor(() => { expect(container.firstChild).toBeInTheDocument() - }) + }, { timeout: 3000 }) }) }) diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 654b667deb..1c0c700177 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({ getMarketplaceUrl: (path: string, _params?: Record) => `https://marketplace.dify.ai${path}`, })) +// Mock marketplace client used by marketplace utils +vi.mock('@/service/client', () => ({ + marketplaceClient: { + collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + collections: [ + { + name: 'collection-1', + label: { 'en-US': 'Collection 1' }, + description: { 'en-US': 'Desc' }, + rule: '', + created_at: '2024-01-01', + updated_at: '2024-01-01', + searchable: true, + search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' }, + }, + ], + }, + })), + collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + plugins: [ + { type: 'plugin', org: 'test', name: 'plugin1', tags: [] }, + ], + }, + })), + // Some utils paths may call searchAdvanced; provide a minimal stub + searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + plugins: [ + { type: 'plugin', org: 'test', name: 'plugin1', tags: [] }, + ], + total: 1, + }, + })), + }, +})) + // Mock context/query-client vi.mock('@/context/query-client', () => ({ TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) =>
{children}
, @@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => { // ================================ // Async Utils Tests // ================================ + +// Narrow mock surface and avoid any in tests +// Types are local to this spec to keep scope minimal + +type FnMock = ReturnType + +type MarketplaceClientMock = { + collectionPlugins: FnMock + collections: FnMock +} + describe('Async Utils', () => { + let marketplaceClientMock: MarketplaceClientMock + + beforeAll(async () => { + const mod = await import('@/service/client') + marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock + }) beforeEach(() => { vi.clearAllMocks() }) @@ -1490,12 +1545,10 @@ describe('Async Utils', () => { { type: 'plugin', org: 'test', name: 'plugin2' }, ] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Adjusted to our mocked marketplaceClient instead of fetch + marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection', { @@ -1504,12 +1557,13 @@ describe('Async Utils', () => { type: 'plugin', }) - expect(globalThis.fetch).toHaveBeenCalled() + expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled() expect(result).toHaveLength(2) }) it('should handle fetch error and return empty array', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + // Simulate error from client + marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error')) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection') @@ -1519,25 +1573,18 @@ describe('Async Utils', () => { it('should pass abort signal when provided', async () => { const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Our client mock receives the signal as second arg + marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const controller = new AbortController() const { getMarketplacePluginsByCollectionId } = await import('./utils') await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalledWith( - expect.any(Request), - expect.any(Object), - ) - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('test-collection') + expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled() + const call = marketplaceClientMock.collectionPlugins.mock.calls[0] + expect(call[1]).toMatchObject({ signal: controller.signal }) }) }) @@ -1548,23 +1595,17 @@ describe('Async Utils', () => { ] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - let callCount = 0 - globalThis.fetch = vi.fn().mockImplementation(() => { - callCount++ - if (callCount === 1) { - return Promise.resolve( - new Response(JSON.stringify({ data: { collections: mockCollections } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Simulate two-step client calls: collections then collectionPlugins + let stage = 0 + marketplaceClientMock.collections.mockImplementationOnce(async () => { + stage = 1 + return { data: { collections: mockCollections } } + }) + marketplaceClientMock.collectionPlugins.mockImplementation(async () => { + if (stage === 1) { + return { data: { plugins: mockPlugins } } } - return Promise.resolve( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + return { data: { plugins: [] } } }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') @@ -1578,7 +1619,8 @@ describe('Async Utils', () => { }) it('should handle fetch error and return empty data', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + // Simulate client error + marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error')) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const result = await getMarketplaceCollectionsAndPlugins() @@ -1588,24 +1630,16 @@ describe('Async Utils', () => { }) it('should append condition and type to URL when provided', async () => { - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { collections: [] } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) - + // Assert that the client was called with query containing condition/type const { getMarketplaceCollectionsAndPlugins } = await import('./utils') await getMarketplaceCollectionsAndPlugins({ condition: 'category=tool', type: 'bundle', }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalled() - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('condition=category%3Dtool') + expect(marketplaceClientMock.collections).toHaveBeenCalled() + const call = marketplaceClientMock.collections.mock.calls[0] + expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) }) }) }) }) diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 2c06b37115..0cc9614cef 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -822,7 +822,7 @@ "count": 2 }, "ts/no-explicit-any": { - "count": 15 + "count": 14 } }, "app/components/base/chat/chat/index.tsx": { diff --git a/web/utils/format.ts b/web/utils/format.ts index 0c81b339a3..ce813d3999 100644 --- a/web/utils/format.ts +++ b/web/utils/format.ts @@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => { : `${formatted}${units[unitIndex].symbol}` } } + // Fallback: if no threshold matched, return the number string + return num.toString() } export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {