diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py index 9e46d72893..402bfdc606 100644 --- a/api/dify_graph/model_runtime/entities/message_entities.py +++ b/api/dify_graph/model_runtime/entities/message_entities.py @@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage): :return: True if prompt message is empty, False otherwise """ - if not super().is_empty() and not self.tool_call_id: - return False - - return True + return super().is_empty() and not self.tool_call_id diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py index 80cf01fb6c..1a57078b98 100644 --- a/api/dify_graph/model_runtime/errors/invoke.py +++ b/api/dify_graph/model_runtime/errors/invoke.py @@ -4,7 +4,8 @@ class InvokeError(ValueError): description: str | None = None def __init__(self, description: str | None = None): - self.description = description + if description is not None: + self.description = description def __str__(self): return self.description or self.__class__.__name__ diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py index e168fc11d1..de0677a348 100644 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -282,7 +282,8 @@ class ModelProviderFactory: all_model_type_models.append(model_schema) simple_provider_schema = provider_schema.to_simple_provider() - simple_provider_schema.models.extend(all_model_type_models) + if model_type: + simple_provider_schema.models = all_model_type_models providers.append(simple_provider_schema) diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py new file mode 100644 index 0000000000..5ecfe01808 --- /dev/null +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -0,0 +1,969 @@ +"""Comprehensive unit tests for core/memory/token_buffer_memory.py""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.memory.token_buffer_memory import TokenBufferMemory +from dify_graph.model_runtime.entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, + UserPromptMessage, +) +from models.model import AppMode + +# --------------------------------------------------------------------------- +# Helpers / shared fixtures +# --------------------------------------------------------------------------- + + +def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock: + """Return a minimal Conversation mock.""" + conv = MagicMock() + conv.id = str(uuid4()) + conv.mode = mode + conv.model_config = {} + return conv + + +def _make_model_instance() -> MagicMock: + """Return a ModelInstance mock whose token counter returns a constant.""" + mi = MagicMock() + mi.get_llm_num_tokens.return_value = 100 + return mi + + +def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock: + msg = MagicMock() + msg.id = str(uuid4()) + msg.query = "user query" + msg.answer = answer + msg.answer_tokens = answer_tokens + msg.workflow_run_id = str(uuid4()) + msg.created_at = MagicMock() + return msg + + +# =========================================================================== +# Tests for __init__ and workflow_run_repo property +# =========================================================================== + + +class TestInit: + def test_init_stores_conversation_and_model_instance(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + assert mem.conversation is conv + assert mem.model_instance is mi + assert mem._workflow_run_repo is None + + def test_workflow_run_repo_is_created_lazily(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + mock_repo = MagicMock() + with ( + patch("core.memory.token_buffer_memory.sessionmaker") as mock_sm, + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=mock_repo, + ), + ): + mock_db.engine = MagicMock() + repo = mem.workflow_run_repo + assert repo is mock_repo + assert mem._workflow_run_repo is mock_repo + + def test_workflow_run_repo_cached_after_first_access(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + existing_repo = MagicMock() + mem._workflow_run_repo = existing_repo + + with patch( + "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" + ) as mock_factory: + repo = mem.workflow_run_repo + mock_factory.assert_not_called() + assert repo is existing_repo + + +# =========================================================================== +# Tests for _build_prompt_message_with_files +# =========================================================================== + + +class TestBuildPromptMessageWithFiles: + """Tests for the private _build_prompt_message_with_files method.""" + + # ------------------------------------------------------------------ + # Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch) + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_no_files_user_message(self, mode): + """When file_extra_config is falsy or app_record is None → plain UserPromptMessage.""" + conv = _make_conversation(mode) + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, # falsy → file_objs = [] + ): + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="hello", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "hello" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_no_files_assistant_message(self, mode): + """Plain AssistantPromptMessage when no files and is_user_message=False.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ): + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="ai reply", + message=_make_message(), + app_record=None, + is_user_message=False, + ) + + assert isinstance(result, AssistantPromptMessage) + assert result.content == "ai reply" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_user_message(self, mode): + """When files are present, returns UserPromptMessage with list content.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None # no detail override + + mock_file_obj = MagicMock() + # Must be a real entity so Pydantic's tagged union discriminator can validate it + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + + mock_message_file = MagicMock() + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=mock_file_obj, + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + result = mem._build_prompt_message_with_files( + message_files=[mock_message_file], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert isinstance(result.content, list) + # Last element should be TextPromptMessageContent + assert isinstance(result.content[-1], TextPromptMessageContent) + assert result.content[-1].data == "user text" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_assistant_message(self, mode): + """When files are present, returns AssistantPromptMessage with list content.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None + + mock_file_obj = MagicMock() + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=mock_file_obj, + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + result = mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="ai text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=False, + ) + + assert isinstance(result, AssistantPromptMessage) + assert isinstance(result.content, list) + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_image_detail_overridden(self, mode): + """When image_config.detail is set, detail is taken from config.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_image_config = MagicMock() + mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = mock_image_config + + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=MagicMock(), + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ) as mock_to_prompt, + ): + mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + # Ensure the LOW detail was passed through + mock_to_prompt.assert_called_once_with( + mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode): + """app_record=None path → file_objs stays empty → plain messages.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ): + result = mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="hello", + message=_make_message(), + app_record=None, # <-- forces the else branch → file_objs = [] + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "hello" + + # ------------------------------------------------------------------ + # Mode: ADVANCED_CHAT / WORKFLOW + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_no_app_raises(self, mode): + """Raises ValueError when conversation.app is falsy.""" + conv = _make_conversation(mode) + conv.app = None + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(ValueError, match="App not found for conversation"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_no_workflow_run_id_raises(self, mode): + """Raises ValueError when message.workflow_run_id is falsy.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + message = _make_message() + message.workflow_run_id = None # force missing + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(ValueError, match="Workflow run ID not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=message, + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_workflow_run_not_found_raises(self, mode): + """Raises ValueError when workflow_run_repo returns None.""" + conv = _make_conversation(mode) + mock_app = MagicMock() + conv.app = mock_app + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = None + + with pytest.raises(ValueError, match="Workflow run not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_workflow_not_found_raises(self, mode): + """Raises ValueError when Workflow lookup returns None.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + mock_workflow_run = MagicMock() + mock_workflow_run.workflow_id = str(uuid4()) + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + ): + mock_db.session.scalar.return_value = None # workflow not found + + with pytest.raises(ValueError, match="Workflow not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_success_no_files_user(self, mode): + """Happy path: workflow mode, no message files → plain UserPromptMessage.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + mock_workflow_run = MagicMock() + mock_workflow_run.workflow_id = str(uuid4()) + + mock_workflow = MagicMock() + mock_workflow.features_dict = {} + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalar.return_value = mock_workflow + + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="wf text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "wf text" + + # ------------------------------------------------------------------ + # Invalid mode + # ------------------------------------------------------------------ + + def test_invalid_mode_raises_assertion(self): + """Any unknown AppMode raises AssertionError.""" + conv = _make_conversation() + conv.mode = "unknown_mode" # not in any set + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(AssertionError, match="Invalid app mode"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + +# =========================================================================== +# Tests for get_history_prompt_messages +# =========================================================================== + + +class TestGetHistoryPromptMessages: + """Tests for get_history_prompt_messages.""" + + def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory: + conv = _make_conversation(mode) + conv.app = MagicMock() + return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + def test_returns_empty_when_no_messages(self): + mem = self._make_memory() + with patch("core.memory.token_buffer_memory.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [] + result = mem.get_history_prompt_messages() + assert result == [] + + def test_skips_first_message_without_answer(self): + """The newest message (index 0 after extraction) without answer and tokens==0 is skipped.""" + mem = self._make_memory() + + msg_no_answer = _make_message(answer="", answer_tokens=0) + msg_no_answer.parent_message_id = None # ensures extract_thread_messages returns it + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg_no_answer], + ), + ): + mock_db.session.scalars.return_value.all.side_effect = [ + [msg_no_answer], # first call: messages query + [], # second call: user files query (never hit, but safe) + ] + result = mem.get_history_prompt_messages() + + assert result == [] + + def test_message_with_answer_not_skipped(self): + """A message with a non-empty answer is NOT popped.""" + mem = self._make_memory() + + msg = _make_message(answer="some answer", answer_tokens=10) + msg.parent_message_id = None + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + # user files query → empty; assistant files query → empty + mock_db.session.scalars.return_value.all.return_value = [] + result = mem.get_history_prompt_messages() + + assert len(result) == 2 # one user + one assistant + + def test_message_limit_default_is_500(self): + """When message_limit is None the stmt is limited to 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=None) + mock_stmt.limit.assert_called_with(500) + + def test_message_limit_clipped_to_500(self): + """A message_limit > 500 is clamped to 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=9999) + mock_stmt.limit.assert_called_with(500) + + def test_message_limit_positive_used(self): + """A positive message_limit < 500 is used as-is.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=10) + mock_stmt.limit.assert_called_with(10) + + def test_message_limit_zero_uses_default(self): + """message_limit=0 triggers the else branch → default 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=0) + mock_stmt.limit.assert_called_with(500) + + def test_user_files_cause_build_with_files_call(self): + """When user_files is non-empty _build_prompt_message_with_files is invoked.""" + mem = self._make_memory() + msg = _make_message() + msg.parent_message_id = None + + mock_user_file = MagicMock() + mock_user_prompt = UserPromptMessage(content="from build") + mock_assistant_prompt = AssistantPromptMessage(content="answer") + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + # messages query + r.all.return_value = [msg] + elif call_count["n"] == 1: + # user files + r.all.return_value = [mock_user_file] + else: + # assistant files + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch.object( + mem, + "_build_prompt_message_with_files", + side_effect=[mock_user_prompt, mock_assistant_prompt], + ) as mock_build, + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + assert mock_build.call_count >= 1 + # First call should be user message + first_call_kwargs = mock_build.call_args_list[0][1] + assert first_call_kwargs["is_user_message"] is True + + def test_assistant_files_cause_build_with_files_call(self): + """When assistant_files is non-empty, build is called with is_user_message=False.""" + mem = self._make_memory() + msg = _make_message() + msg.parent_message_id = None + + mock_assistant_file = MagicMock() + mock_user_prompt = UserPromptMessage(content="query") + mock_assistant_prompt = AssistantPromptMessage(content="built") + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + elif call_count["n"] == 1: + r.all.return_value = [] # no user files + else: + r.all.return_value = [mock_assistant_file] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch.object( + mem, + "_build_prompt_message_with_files", + return_value=mock_assistant_prompt, + ) as mock_build, + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + mock_build.assert_called_once() + call_kwargs = mock_build.call_args[1] + assert call_kwargs["is_user_message"] is False + + def test_token_pruning_removes_oldest_messages(self): + """If tokens exceed limit, oldest messages are removed until within limit.""" + conv = _make_conversation() + conv.app = MagicMock() + + # Model returns tokens that decrease only after removing pairs + token_values = [3000, 1500] # first call over limit, second within + mi = MagicMock() + mi.get_llm_num_tokens.side_effect = token_values + + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=2000) + + # After pruning, we should have fewer than the 2 initial messages + assert len(result) <= 1 + + def test_token_pruning_stops_at_single_message(self): + """Pruning stops when only 1 message remains (to prevent empty list).""" + conv = _make_conversation() + conv.app = MagicMock() + + # Always over limit + mi = MagicMock() + mi.get_llm_num_tokens.return_value = 99999 + + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=1) + + # At least 1 message should remain + assert len(result) >= 1 + + def test_no_pruning_when_within_limit(self): + """When tokens ≤ limit, no pruning occurs.""" + mem = self._make_memory() + mem.model_instance.get_llm_num_tokens.return_value = 50 # well under default 2000 + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=2000) + + assert len(result) == 2 # user + assistant + + def test_plain_user_and_assistant_messages_returned(self): + """Without files, plain UserPromptMessage and AssistantPromptMessage appear.""" + mem = self._make_memory() + + msg = _make_message(answer="My answer") + msg.query = "My query" + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + assert len(result) == 2 + user_msg, ai_msg = result + assert isinstance(user_msg, UserPromptMessage) + assert user_msg.content == "My query" + assert isinstance(ai_msg, AssistantPromptMessage) + assert ai_msg.content == "My answer" + + +# =========================================================================== +# Tests for get_history_prompt_text +# =========================================================================== + + +class TestGetHistoryPromptText: + """Tests for get_history_prompt_text.""" + + def _make_memory(self) -> TokenBufferMemory: + conv = _make_conversation() + conv.app = MagicMock() + return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + def test_empty_messages_returns_empty_string(self): + mem = self._make_memory() + with patch.object(mem, "get_history_prompt_messages", return_value=[]): + result = mem.get_history_prompt_text() + assert result == "" + + def test_user_and_assistant_messages_formatted(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Hello"), + AssistantPromptMessage(content="World"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text(human_prefix="H", ai_prefix="A") + assert result == "H: Hello\nA: World" + + def test_custom_prefixes_applied(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Bye"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot") + assert "Human: Hi" in result + assert "Bot: Bye" in result + + def test_list_content_with_text_and_image(self): + """List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image].""" + mem = self._make_memory() + messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="caption"), + ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "caption" in result + assert "[image]" in result + + def test_list_content_text_only(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content=[TextPromptMessageContent(data="just text")]), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "just text" in result + + def test_list_content_image_only(self): + mem = self._make_memory() + messages = [ + UserPromptMessage( + content=[ + ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "[image]" in result + + def test_unknown_role_skipped(self): + """Messages with a role that is not USER or ASSISTANT are skipped.""" + mem = self._make_memory() + + # Create a mock message with a SYSTEM role + system_msg = MagicMock() + system_msg.role = PromptMessageRole.SYSTEM + system_msg.content = "system instruction" + + user_msg = UserPromptMessage(content="hi") + + with patch.object(mem, "get_history_prompt_messages", return_value=[system_msg, user_msg]): + result = mem.get_history_prompt_text() + + assert "system instruction" not in result + assert "Human: hi" in result + + def test_passes_max_token_limit_and_message_limit(self): + """Parameters are forwarded to get_history_prompt_messages.""" + mem = self._make_memory() + with patch.object(mem, "get_history_prompt_messages", return_value=[]) as mock_get: + mem.get_history_prompt_text(max_token_limit=500, message_limit=10) + mock_get.assert_called_once_with(max_token_limit=500, message_limit=10) + + def test_multiple_messages_joined_by_newline(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Q1"), + AssistantPromptMessage(content="A1"), + UserPromptMessage(content="Q2"), + AssistantPromptMessage(content="A2"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + lines = result.split("\n") + assert len(lines) == 4 + assert lines[0] == "Human: Q1" + assert lines[1] == "Assistant: A1" + assert lines[2] == "Human: Q2" + assert lines[3] == "Assistant: A2" + + def test_assistant_list_content_formatted(self): + """AssistantPromptMessage with list content is also handled.""" + mem = self._make_memory() + messages = [ + AssistantPromptMessage( + content=[ + TextPromptMessageContent(data="response text"), + ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "response text" in result + assert "[image]" in result diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/__init__.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__init__.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__init__.py rename to api/tests/unit_tests/dify_graph/model_runtime/__init__.py diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py new file mode 100644 index 0000000000..2410d16d63 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py @@ -0,0 +1,964 @@ +"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py""" + +from unittest.mock import MagicMock, patch + +import pytest + +from dify_graph.model_runtime.callbacks.base_callback import ( + _TEXT_COLOR_MAPPING, + Callback, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool + +# --------------------------------------------------------------------------- +# Concrete implementation of the abstract Callback for testing +# --------------------------------------------------------------------------- + + +class ConcreteCallback(Callback): + """A minimal concrete subclass that satisfies all abstract methods.""" + + def __init__(self, raise_error: bool = False): + self.raise_error = raise_error + # Track invocations + self.before_invoke_calls: list[dict] = [] + self.new_chunk_calls: list[dict] = [] + self.after_invoke_calls: list[dict] = [] + self.invoke_error_calls: list[dict] = [] + + def on_before_invoke( + self, + llm_instance, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.before_invoke_calls.append( + { + "llm_instance": llm_instance, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + # To cover the 'raise NotImplementedError()' in the base class + try: + super().on_before_invoke( + llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_new_chunk( + self, + llm_instance, + chunk, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.new_chunk_calls.append( + { + "llm_instance": llm_instance, + "chunk": chunk, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_new_chunk( + llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_after_invoke( + self, + llm_instance, + result, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.after_invoke_calls.append( + { + "llm_instance": llm_instance, + "result": result, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_after_invoke( + llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_invoke_error( + self, + llm_instance, + ex, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.invoke_error_calls.append( + { + "llm_instance": llm_instance, + "ex": ex, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_invoke_error( + llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + +# --------------------------------------------------------------------------- +# A subclass that deliberately leaves abstract methods un-implemented, +# used to verify that instantiation raises TypeError. +# --------------------------------------------------------------------------- + + +# =========================================================================== +# Tests for _TEXT_COLOR_MAPPING module-level constant +# =========================================================================== + + +class TestTextColorMapping: + """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" + + def test_contains_all_expected_colors(self): + expected_keys = {"blue", "yellow", "pink", "green", "red"} + assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys + + def test_blue_escape_code(self): + assert _TEXT_COLOR_MAPPING["blue"] == "36;1" + + def test_yellow_escape_code(self): + assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" + + def test_pink_escape_code(self): + assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" + + def test_green_escape_code(self): + assert _TEXT_COLOR_MAPPING["green"] == "32;1" + + def test_red_escape_code(self): + assert _TEXT_COLOR_MAPPING["red"] == "31;1" + + def test_mapping_is_dict(self): + assert isinstance(_TEXT_COLOR_MAPPING, dict) + + def test_all_values_are_strings(self): + for key, value in _TEXT_COLOR_MAPPING.items(): + assert isinstance(value, str), f"Value for {key!r} should be str" + + +# =========================================================================== +# Tests for the Callback ABC itself +# =========================================================================== + + +class TestCallbackAbstract: + """Tests verifying Callback is a proper ABC.""" + + def test_cannot_instantiate_abstract_class_directly(self): + """Callback cannot be instantiated since it has abstract methods.""" + with pytest.raises(TypeError): + Callback() # type: ignore[abstract] + + def test_concrete_subclass_can_be_instantiated(self): + cb = ConcreteCallback() + assert isinstance(cb, Callback) + + def test_default_raise_error_is_false(self): + cb = ConcreteCallback() + assert cb.raise_error is False + + def test_raise_error_can_be_set_to_true(self): + cb = ConcreteCallback(raise_error=True) + assert cb.raise_error is True + + def test_subclass_missing_on_before_invoke_raises_type_error(self): + """A subclass missing any single abstract method cannot be instantiated.""" + + class IncompleteCallback(Callback): + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_new_chunk_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_after_invoke_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_invoke_error_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for the on_before_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.model = "gpt-4" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 0.7} + + def test_on_before_invoke_called_with_required_args(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 1 + call = self.cb.before_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["model"] == self.model + assert call["credentials"] == self.credentials + assert call["prompt_messages"] is self.prompt_messages + assert call["model_parameters"] is self.model_parameters + + def test_on_before_invoke_defaults_tools_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["tools"] is None + + def test_on_before_invoke_defaults_stop_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stop"] is None + + def test_on_before_invoke_defaults_stream_true(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stream"] is True + + def test_on_before_invoke_defaults_user_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["user"] is None + + def test_on_before_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["stop1", "stop2"] + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="user-123", + ) + call = self.cb.before_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "user-123" + + def test_on_before_invoke_called_multiple_times(self): + for i in range(3): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=f"model-{i}", + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 3 + assert self.cb.before_invoke_calls[2]["model"] == "model-2" + + +# =========================================================================== +# Tests for on_new_chunk +# =========================================================================== + + +class TestOnNewChunk: + """Tests for the on_new_chunk callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.chunk = MagicMock(spec=LLMResultChunk) + self.model = "gpt-3.5-turbo" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"max_tokens": 256} + + def test_on_new_chunk_called_with_required_args(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 1 + call = self.cb.new_chunk_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["chunk"] is self.chunk + assert call["model"] == self.model + assert call["credentials"] == self.credentials + + def test_on_new_chunk_defaults_tools_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["tools"] is None + + def test_on_new_chunk_defaults_stop_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stop"] is None + + def test_on_new_chunk_defaults_stream_true(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stream"] is True + + def test_on_new_chunk_defaults_user_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["user"] is None + + def test_on_new_chunk_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["END"] + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="chunk-user", + ) + call = self.cb.new_chunk_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "chunk-user" + + def test_on_new_chunk_called_multiple_times(self): + for i in range(5): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 5 + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for the on_after_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.result = MagicMock(spec=LLMResult) + self.model = "claude-3" + self.credentials = {"api_key": "anthropic-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 1.0} + + def test_on_after_invoke_called_with_required_args(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.after_invoke_calls) == 1 + call = self.cb.after_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["result"] is self.result + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_after_invoke_defaults_tools_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["tools"] is None + + def test_on_after_invoke_defaults_stop_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stop"] is None + + def test_on_after_invoke_defaults_stream_true(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stream"] is True + + def test_on_after_invoke_defaults_user_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["user"] is None + + def test_on_after_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["STOP"] + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="after-user", + ) + call = self.cb.after_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "after-user" + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for the on_invoke_error callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.ex = ValueError("something went wrong") + self.model = "gemini-pro" + self.credentials = {"api_key": "google-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"top_p": 0.9} + + def test_on_invoke_error_called_with_required_args(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.invoke_error_calls) == 1 + call = self.cb.invoke_error_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["ex"] is self.ex + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_invoke_error_defaults_tools_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["tools"] is None + + def test_on_invoke_error_defaults_stop_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stop"] is None + + def test_on_invoke_error_defaults_stream_true(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stream"] is True + + def test_on_invoke_error_defaults_user_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["user"] is None + + def test_on_invoke_error_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["HALT"] + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="error-user", + ) + call = self.cb.invoke_error_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "error-user" + + def test_on_invoke_error_accepts_various_exception_types(self): + for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]: + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=exc, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.invoke_error_calls) == 3 + + +# =========================================================================== +# Tests for print_text (concrete method on Callback) +# =========================================================================== + + +class TestPrintText: + """Tests for the concrete print_text method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + def test_print_text_without_color_prints_plain_text(self, capsys): + self.cb.print_text("hello world") + captured = capsys.readouterr() + assert captured.out == "hello world" + + def test_print_text_with_color_prints_colored_text(self, capsys): + self.cb.print_text("colored text", color="blue") + captured = capsys.readouterr() + # Should contain ANSI escape sequences + assert "colored text" in captured.out + assert "\001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out + + def test_print_text_without_color_no_ansi(self, capsys): + self.cb.print_text("plain text", color=None) + captured = capsys.readouterr() + assert captured.out == "plain text" + # No ANSI escape sequences + assert "\x1b" not in captured.out + + def test_print_text_default_end_is_empty_string(self, capsys): + self.cb.print_text("no newline") + captured = capsys.readouterr() + assert not captured.out.endswith("\n") + + def test_print_text_with_custom_end(self, capsys): + self.cb.print_text("with newline", end="\n") + captured = capsys.readouterr() + assert captured.out.endswith("\n") + + def test_print_text_with_empty_string(self, capsys): + self.cb.print_text("", color=None) + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_print_text_all_colors_work(self, color, capsys): + """Verify no KeyError is thrown for any valid color.""" + self.cb.print_text("test", color=color) + captured = capsys.readouterr() + assert "test" in captured.out + + def test_print_text_calls_get_colored_text_when_color_given(self): + with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: + with patch("builtins.print") as mock_print: + self.cb.print_text("hello", color="green") + mock_gct.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("[COLORED]", end="") + + def test_print_text_does_not_call_get_colored_text_when_no_color(self): + with patch.object(self.cb, "_get_colored_text") as mock_gct: + with patch("builtins.print"): + self.cb.print_text("hello", color=None) + mock_gct.assert_not_called() + + def test_print_text_passes_end_to_print(self): + with patch("builtins.print") as mock_print: + self.cb.print_text("text", end="---") + mock_print.assert_called_once_with("text", end="---") + + +# =========================================================================== +# Tests for _get_colored_text (private helper method) +# =========================================================================== + + +class TestGetColoredText: + """Tests for the _get_colored_text private method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) + def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): + result = self.cb._get_colored_text("text", color) + assert expected_code in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_contains_input_text(self, color): + result = self.cb._get_colored_text("hello", color) + assert "hello" in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_starts_with_escape(self, color): + result = self.cb._get_colored_text("text", color) + # Should start with an ANSI escape (\x1b or \u001b) + assert result.startswith("\x1b[") or result.startswith("\u001b[") + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_ends_with_reset(self, color): + result = self.cb._get_colored_text("text", color) + # Should end with the ANSI reset code + assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") + + def test_get_colored_text_returns_string(self): + result = self.cb._get_colored_text("text", "blue") + assert isinstance(result, str) + + def test_get_colored_text_blue_exact_format(self): + result = self.cb._get_colored_text("hello", "blue") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" + assert result == expected + + def test_get_colored_text_red_exact_format(self): + result = self.cb._get_colored_text("error", "red") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" + assert result == expected + + def test_get_colored_text_green_exact_format(self): + result = self.cb._get_colored_text("ok", "green") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" + assert result == expected + + def test_get_colored_text_yellow_exact_format(self): + result = self.cb._get_colored_text("warn", "yellow") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" + assert result == expected + + def test_get_colored_text_pink_exact_format(self): + result = self.cb._get_colored_text("info", "pink") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" + assert result == expected + + def test_get_colored_text_empty_string(self): + result = self.cb._get_colored_text("", "blue") + assert isinstance(result, str) + # Empty text should still have escape codes + assert _TEXT_COLOR_MAPPING["blue"] in result + + def test_get_colored_text_invalid_color_raises_key_error(self): + with pytest.raises(KeyError): + self.cb._get_colored_text("text", "purple") + + def test_get_colored_text_with_special_characters(self): + special = "hello\nworld\ttab" + result = self.cb._get_colored_text(special, "blue") + assert special in result + + def test_get_colored_text_with_long_text(self): + long_text = "a" * 10000 + result = self.cb._get_colored_text(long_text, "green") + assert long_text in result + + +# =========================================================================== +# Integration-style tests: full workflow through a ConcreteCallback +# =========================================================================== + + +class TestConcreteCallbackIntegration: + """End-to-end workflow tests using ConcreteCallback.""" + + def test_full_invocation_lifecycle(self): + """Simulate a complete LLM invocation lifecycle through all callbacks.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4o" + credentials = {"api_key": "sk-xyz"} + prompt_messages = [MagicMock(spec=PromptMessage)] + model_parameters = {"temperature": 0.5} + tools = [MagicMock(spec=PromptMessageTool)] + stop = [""] + user = "user-abc" + + # 1. Before invoke + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 2. Multiple chunks during streaming + for i in range(3): + chunk = MagicMock(spec=LLMResultChunk) + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 3. After invoke + result = MagicMock(spec=LLMResult) + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.new_chunk_calls) == 3 + assert len(cb.after_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 0 + + def test_error_lifecycle(self): + """Simulate an invoke that results in an error.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4" + credentials = {} + prompt_messages = [] + model_parameters = {} + + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + ex = RuntimeError("API timeout") + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 1 + assert cb.invoke_error_calls[0]["ex"] is ex + assert len(cb.after_invoke_calls) == 0 + + def test_print_text_with_color_in_integration(self, capsys): + """verify print_text works correctly in a concrete instance.""" + cb = ConcreteCallback() + cb.print_text("SUCCESS", color="green", end="\n") + captured = capsys.readouterr() + assert "SUCCESS" in captured.out + assert "\n" in captured.out + + def test_print_text_no_color_in_integration(self, capsys): + cb = ConcreteCallback() + cb.print_text("plain output") + captured = capsys.readouterr() + assert captured.out == "plain output" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py new file mode 100644 index 0000000000..0c6c1fd191 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py @@ -0,0 +1,700 @@ +""" +Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py + +Coverage targets: + - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream, + prompt_message.name, model_parameters) + - LoggingCallback.on_new_chunk (writes to stdout) + - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent) + - LoggingCallback.on_invoke_error (logs exception via logger.exception) +""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_usage() -> LLMUsage: + """Return a minimal LLMUsage instance.""" + return LLMUsage( + prompt_tokens=10, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.01"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.002"), + completion_price=Decimal("0.04"), + total_tokens=30, + total_price=Decimal("0.05"), + currency="USD", + latency=0.5, + ) + + +def _make_llm_result( + content: str = "hello world", + tool_calls: list | None = None, + model: str = "gpt-4", + system_fingerprint: str | None = "fp-abc", +) -> LLMResult: + """Return an LLMResult with an AssistantPromptMessage.""" + assistant_msg = AssistantPromptMessage( + content=content, + tool_calls=tool_calls or [], + ) + return LLMResult( + model=model, + message=assistant_msg, + usage=_make_usage(), + system_fingerprint=system_fingerprint, + ) + + +def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: + """Return a minimal LLMResultChunk.""" + return LLMResultChunk( + model="gpt-4", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content), + ), + ) + + +def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: + return UserPromptMessage(content=content, name=name) + + +def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: + return SystemPromptMessage(content=content) + + +def _make_tool(name: str = "my_tool") -> PromptMessageTool: + return PromptMessageTool(name=name, description="A tool", parameters={}) + + +def _make_tool_call( + call_id: str = "call-1", + func_name: str = "some_func", + arguments: str = '{"key": "value"}', +) -> AssistantPromptMessage.ToolCall: + return AssistantPromptMessage.ToolCall( + id=call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), + ) + + +# --------------------------------------------------------------------------- +# Fixture: shared LoggingCallback instance (no heavy state) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cb() -> LoggingCallback: + return LoggingCallback() + + +@pytest.fixture +def llm_instance() -> MagicMock: + return MagicMock() + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for LoggingCallback.on_before_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + *, + model: str = "gpt-4", + credentials: dict | None = None, + prompt_messages: list | None = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, + stream: bool = True, + user: str | None = None, + ): + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials or {}, + prompt_messages=prompt_messages or [], + model_parameters=model_parameters or {}, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): + """Calling with bare-minimum args should not raise.""" + self._invoke(cb, llm_instance) + + def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The model name must appear in print_text calls.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model="claude-3") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "claude-3" in calls_text + + def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Each key-value pair of model_parameters must be printed.""" + params = {"temperature": 0.7, "max_tokens": 512} + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model_parameters=params) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "temperature" in calls_text + assert "0.7" in calls_text + assert "max_tokens" in calls_text + assert "512" in calls_text + + def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): + """Empty model_parameters dict should not raise.""" + self._invoke(cb, llm_instance, model_parameters={}) + + # ------------------------------------------------------------------ + # stop branch + # ------------------------------------------------------------------ + + def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """stop words must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=["STOP", "END"]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "stop" in calls_text + + def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=None the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=[] (falsy) the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + # ------------------------------------------------------------------ + # tools branch + # ------------------------------------------------------------------ + + def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """Tool names must appear in output when tools are provided.""" + tools = [_make_tool("search"), _make_tool("calculate")] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=tools) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "search" in calls_text + assert "calculate" in calls_text + + def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=None the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=[] (falsy) the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + # ------------------------------------------------------------------ + # user branch + # ------------------------------------------------------------------ + + def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """User string must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, user="alice") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "alice" in calls_text + + def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When user=None the User line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, user=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "User:" not in calls_text + + # ------------------------------------------------------------------ + # stream branch + # ------------------------------------------------------------------ + + def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=True the [on_llm_new_chunk] marker must be printed.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=True) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" in calls_text + + def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=False) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" not in calls_text + + # ------------------------------------------------------------------ + # prompt_messages branch + # ------------------------------------------------------------------ + + def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has a name it must be printed.""" + msg = _make_user_prompt("hi", name="bob") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "bob" in calls_text + + def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has no name the name line must NOT appear.""" + msg = _make_user_prompt("hi", name=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tname:" not in calls_text + + def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Role and content of each PromptMessage must appear in output.""" + msg = _make_system_prompt("Be concise.") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "system" in calls_text + assert "Be concise." in calls_text + + def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All entries in prompt_messages are iterated and printed.""" + msgs = [ + _make_system_prompt("sys"), + _make_user_prompt("user msg"), + ] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=msgs) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "sys" in calls_text + assert "user msg" in calls_text + + # ------------------------------------------------------------------ + # Combination: everything provided + # ------------------------------------------------------------------ + + def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock): + """Supply stop, tools, user, multiple params, named message – no exception.""" + msgs = [_make_user_prompt("question", name="alice")] + tools = [_make_tool("tool_a")] + with patch.object(cb, "print_text"): + self._invoke( + cb, + llm_instance, + model="gpt-3.5", + model_parameters={"temperature": 1.0, "top_p": 0.9}, + tools=tools, + stop=["DONE"], + stream=True, + user="alice", + prompt_messages=msgs, + ) + + +# =========================================================================== +# Tests for on_new_chunk +# =========================================================================== + + +class TestOnNewChunk: + """Tests for LoggingCallback.on_new_chunk.""" + + def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock): + """on_new_chunk must write the chunk's text content to sys.stdout.""" + chunk = _make_chunk("hello from LLM") + written = [] + + with patch("sys.stdout") as mock_stdout: + mock_stdout.write.side_effect = written.append + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("hello from LLM") + mock_stdout.flush.assert_called_once() + + def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works correctly even when the chunk content is an empty string.""" + chunk = _make_chunk("") + with patch("sys.stdout") as mock_stdout: + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("") + mock_stdout.flush.assert_called_once() + + def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters are accepted without errors.""" + chunk = _make_chunk("data") + with patch("sys.stdout"): + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.5}, + tools=[_make_tool("t1")], + stop=["EOS"], + stream=True, + user="bob", + ) + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for LoggingCallback.on_after_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + result: LLMResult, + **kwargs, + ): + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """After-invoke header, content, model, usage, fingerprint must be printed.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_after_invoke]" in calls_text + assert "hello world" in calls_text + assert "gpt-4" in calls_text + assert "fp-abc" in calls_text + + def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock): + """When there are no tool_calls the 'Tool calls:' block must NOT appear.""" + result = _make_llm_result(tool_calls=[]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" not in calls_text + + def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tool_calls exist their id, name, and JSON arguments must be printed.""" + tc = _make_tool_call( + call_id="call-xyz", + func_name="fetch_data", + arguments='{"url": "https://example.com"}', + ) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" in calls_text + assert "call-xyz" in calls_text + assert "fetch_data" in calls_text + # arguments should be JSON-dumped + assert "https://example.com" in calls_text + + def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All tool calls in the list must be iterated.""" + tcs = [ + _make_tool_call("id-1", "func_a", '{"a": 1}'), + _make_tool_call("id-2", "func_b", '{"b": 2}'), + ] + result = _make_llm_result(tool_calls=tcs) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "id-1" in calls_text + assert "func_a" in calls_text + assert "id-2" in calls_text + assert "func_b" in calls_text + + def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When system_fingerprint is None it should still be printed (as None).""" + result = _make_llm_result(system_fingerprint=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "System Fingerprint: None" in calls_text + + def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The usage object must appear in the printed output.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Usage:" in calls_text + + def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock): + """Verify json.dumps is applied to the arguments field (a string).""" + raw_args = '{"x": 42}' + tc = _make_tool_call(arguments=raw_args) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + + # Check if any call to print_text included the expected (json-encoded) arguments + # json.dumps(raw_args) produces a string starting and ending with quotes + expected_substring = json.dumps(raw_args) + found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list) + assert found, f"Expected {expected_substring} to be printed in one of the calls" + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + result = _make_llm_result() + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.9}, + tools=[_make_tool("t")], + stop=[""], + stream=False, + user="carol", + ) + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for LoggingCallback.on_invoke_error.""" + + def _invoke_error( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + ex: Exception, + **kwargs, + ): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """The [on_llm_invoke_error] banner must be printed.""" + with patch.object(cb, "print_text") as mock_print: + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, RuntimeError("boom")) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_invoke_error]" in calls_text + + def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock): + """logger.exception must be called with the exception.""" + ex = ValueError("something went wrong") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works with any exception type (TypeError, IOError, etc.).""" + for exc_cls in (TypeError, IOError, KeyError, Exception): + ex = exc_cls("error") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + ex = RuntimeError("fail") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger"): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.7}, + tools=[_make_tool("t")], + stop=["STOP"], + stream=True, + user="dave", + ) + + +# =========================================================================== +# Tests for print_text (inherited from Callback, exercised through LoggingCallback) +# =========================================================================== + + +class TestPrintText: + """Verify that print_text from the Callback base class works correctly.""" + + def test_print_text_with_color(self, cb: LoggingCallback, capsys): + """print_text with a known colour should emit an ANSI escape sequence.""" + cb.print_text("hello", color="blue") + captured = capsys.readouterr() + assert "hello" in captured.out + # ANSI escape codes should be present + assert "\x1b[" in captured.out + + def test_print_text_without_color(self, cb: LoggingCallback, capsys): + """print_text without colour should print plain text.""" + cb.print_text("plain text") + captured = capsys.readouterr() + assert "plain text" in captured.out + + def test_print_text_all_colours(self, cb: LoggingCallback, capsys): + """Verify all supported colour keys don't raise.""" + for colour in ("blue", "yellow", "pink", "green", "red"): + cb.print_text("x", color=colour) + captured = capsys.readouterr() + # All outputs should contain 'x' (5 calls) + assert captured.out.count("x") >= 5 + + +# =========================================================================== +# Integration-style test: real print_text called (no mocking) +# =========================================================================== + + +class TestLoggingCallbackIntegration: + """Light integration tests – real print_text calls, just checking no exceptions.""" + + def test_on_before_invoke_full_run(self, capsys): + """Full on_before_invoke run with all optional fields – verifies real output.""" + cb = LoggingCallback() + llm = MagicMock() + msgs = [_make_user_prompt("Who are you?", name="tester")] + tools = [_make_tool("calculator")] + cb.on_before_invoke( + llm_instance=llm, + model="gpt-4-turbo", + credentials={"api_key": "sk-xxx"}, + prompt_messages=msgs, + model_parameters={"temperature": 0.8}, + tools=tools, + stop=["STOP"], + stream=True, + user="test_user", + ) + captured = capsys.readouterr() + assert "gpt-4-turbo" in captured.out + assert "calculator" in captured.out + assert "test_user" in captured.out + assert "STOP" in captured.out + assert "tester" in captured.out + + def test_on_new_chunk_full_run(self, capsys): + """Full on_new_chunk run – verifies real stdout write.""" + cb = LoggingCallback() + chunk = _make_chunk("streaming token") + cb.on_new_chunk( + llm_instance=MagicMock(), + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "streaming token" in captured.out + + def test_on_after_invoke_full_run_with_tool_calls(self, capsys): + """Full on_after_invoke run with tool calls – verifies real output.""" + cb = LoggingCallback() + tc = _make_tool_call("call-99", "do_thing", '{"n": 5}') + result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz") + cb.on_after_invoke( + llm_instance=MagicMock(), + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "result content" in captured.out + assert "call-99" in captured.out + assert "do_thing" in captured.out + assert "fp-xyz" in captured.out + + def test_on_invoke_error_full_run(self, capsys): + """Full on_invoke_error run – just verifies no exception is raised.""" + cb = LoggingCallback() + ex = RuntimeError("something bad happened") + # logger.exception writes to stderr; we just confirm it doesn't crash + cb.on_invoke_error( + llm_instance=MagicMock(), + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "[on_llm_invoke_error]" in captured.out diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py new file mode 100644 index 0000000000..db147fb0cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py @@ -0,0 +1,35 @@ +from dify_graph.model_runtime.entities.common_entities import I18nObject + + +class TestI18nObject: + def test_i18n_object_with_both_languages(self): + """ + Test I18nObject when both zh_Hans and en_US are provided. + """ + i18n = I18nObject(zh_Hans="你好", en_US="Hello") + assert i18n.zh_Hans == "你好" + assert i18n.en_US == "Hello" + + def test_i18n_object_fallback_to_en_us(self): + """ + Test I18nObject when zh_Hans is missing, it should fallback to en_US. + """ + i18n = I18nObject(en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_none_zh_hans(self): + """ + Test I18nObject when zh_Hans is None, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans=None, en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_empty_zh_hans(self): + """ + Test I18nObject when zh_Hans is an empty string, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans="", en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py rename to api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py new file mode 100644 index 0000000000..a96a38f5cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py @@ -0,0 +1,210 @@ +import pytest + +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, + VideoPromptMessageContent, +) + + +class TestPromptMessageRole: + def test_value_of(self): + assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM + assert PromptMessageRole.value_of("user") == PromptMessageRole.USER + assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT + assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL + + with pytest.raises(ValueError, match="invalid prompt message type value invalid"): + PromptMessageRole.value_of("invalid") + + +class TestPromptMessageEntities: + def test_prompt_message_tool(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + assert tool.name == "test_tool" + assert tool.description == "test desc" + assert tool.parameters == {"foo": "bar"} + + def test_prompt_message_function(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + func = PromptMessageFunction(function=tool) + assert func.type == "function" + assert func.function == tool + + +class TestPromptMessageContent: + def test_text_content(self): + content = TextPromptMessageContent(data="hello") + assert content.type == PromptMessageContentType.TEXT + assert content.data == "hello" + + def test_image_content(self): + content = ImagePromptMessageContent( + format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH + ) + assert content.type == PromptMessageContentType.IMAGE + assert content.detail == ImagePromptMessageContent.DETAIL.HIGH + assert content.data == "data:image/jpeg;base64,abc" + + def test_image_content_url(self): + content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg") + assert content.data == "https://example.com/image.jpg" + + def test_audio_content(self): + content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg") + assert content.type == PromptMessageContentType.AUDIO + assert content.data == "data:audio/mpeg;base64,abc" + + def test_video_content(self): + content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4") + assert content.type == PromptMessageContentType.VIDEO + assert content.data == "data:video/mp4;base64,abc" + + def test_document_content(self): + content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf") + assert content.type == PromptMessageContentType.DOCUMENT + assert content.data == "data:application/pdf;base64,abc" + + +class TestPromptMessages: + def test_user_prompt_message(self): + msg = UserPromptMessage(content="hello") + assert msg.role == PromptMessageRole.USER + assert msg.content == "hello" + assert msg.is_empty() is False + assert msg.get_text_content() == "hello" + + def test_user_prompt_message_complex_content(self): + content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")] + msg = UserPromptMessage(content=content) + assert msg.get_text_content() == "hello world" + + # Test validation from dict + msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}]) + assert isinstance(msg2.content[0], TextPromptMessageContent) + assert msg2.content[0].data == "hi" + + def test_prompt_message_empty(self): + msg = UserPromptMessage(content=None) + assert msg.is_empty() is True + assert msg.get_text_content() == "" + + def test_assistant_prompt_message(self): + msg = AssistantPromptMessage(content="thinking...") + assert msg.role == PromptMessageRole.ASSISTANT + assert msg.is_empty() is False + + tool_call = AssistantPromptMessage.ToolCall( + id="call_1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call]) + assert msg_with_tools.is_empty() is False + assert msg_with_tools.role == PromptMessageRole.ASSISTANT + + def test_assistant_tool_call_id_transform(self): + tool_call = AssistantPromptMessage.ToolCall( + id=123, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + assert tool_call.id == "123" + + def test_system_prompt_message(self): + msg = SystemPromptMessage(content="you are a bot") + assert msg.role == PromptMessageRole.SYSTEM + assert msg.content == "you are a bot" + + def test_tool_prompt_message(self): + # Case 1: Both content and tool_call_id are present + msg = ToolPromptMessage(content="result", tool_call_id="call_1") + assert msg.role == PromptMessageRole.TOOL + assert msg.tool_call_id == "call_1" + assert msg.is_empty() is False + + # Case 2: Content is present, but tool_call_id is empty + msg_content_only = ToolPromptMessage(content="result", tool_call_id="") + assert msg_content_only.is_empty() is False + + # Case 3: Content is None, but tool_call_id is present + msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1") + assert msg_id_only.is_empty() is False + + # Case 4: Both content and tool_call_id are empty + msg_empty = ToolPromptMessage(content=None, tool_call_id="") + assert msg_empty.is_empty() is True + + def test_prompt_message_validation_errors(self): + with pytest.raises(KeyError): + # Invalid content type in list + UserPromptMessage(content=[{"type": "invalid", "data": "foo"}]) + + with pytest.raises(ValueError, match="invalid prompt message"): + # Not a dict or PromptMessageContent + UserPromptMessage(content=[123]) + + def test_prompt_message_serialization(self): + # Case: content is None + assert UserPromptMessage(content=None).serialize_content(None) is None + + # Case: content is str + assert UserPromptMessage(content="hello").serialize_content("hello") == "hello" + + # Case: content is list of dict + content_list = [{"type": "text", "data": "hi"}] + msg = UserPromptMessage(content=content_list) + assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}] + + # Case: content is Sequence but not list (e.g. tuple) + # To hit line 204, we can call serialize_content manually or + # try to pass a type that pydantic doesn't convert to list in its internal state. + # Actually, let's just call it manually on the instance. + msg = UserPromptMessage(content="test") + content_tuple = (TextPromptMessageContent(data="hi"),) + assert msg.serialize_content(content_tuple) == content_tuple + + def test_prompt_message_mixed_content_validation(self): + # Test branch: isinstance(prompt, PromptMessageContent) + # but not (TextPromptMessageContent | MultiModalPromptMessageContent) + # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) + + # We need a PromptMessageContent that is NOT Text or MultiModal. + # But PromptMessageContentUnionTypes discriminator handles this usually. + # We can bypass high-level validation by passing the object directly in a list. + + class MockContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.TEXT + data: str + + mock_item = MockContent(data="test") + msg = UserPromptMessage(content=[mock_item]) + # It should hit line 187 and convert to TextPromptMessageContent + assert isinstance(msg.content[0], TextPromptMessageContent) + assert msg.content[0].data == "test" + + def test_prompt_message_get_text_content_branches(self): + # content is None + msg_none = UserPromptMessage(content=None) + assert msg_none.get_text_content() == "" + + # content is list but no text content + image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg") + msg_image = UserPromptMessage(content=[image]) + assert msg_image.get_text_content() == "" + + # content is list with mixed + text = TextPromptMessageContent(data="hello") + msg_mixed = UserPromptMessage(content=[text, image]) + assert msg_mixed.get_text_content() == "hello" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py new file mode 100644 index 0000000000..3d03361f2a --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py @@ -0,0 +1,220 @@ +from decimal import Decimal + +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ModelUsage, + ParameterRule, + ParameterType, + PriceConfig, + PriceInfo, + PriceType, + ProviderModel, +) + + +class TestModelType: + def test_value_of(self): + assert ModelType.value_of("text-generation") == ModelType.LLM + assert ModelType.value_of(ModelType.LLM) == ModelType.LLM + assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING + assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING + assert ModelType.value_of("reranking") == ModelType.RERANK + assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK + assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT + assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT + assert ModelType.value_of("tts") == ModelType.TTS + assert ModelType.value_of(ModelType.TTS) == ModelType.TTS + assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION + + with pytest.raises(ValueError, match="invalid origin model type invalid"): + ModelType.value_of("invalid") + + def test_to_origin_model_type(self): + assert ModelType.LLM.to_origin_model_type() == "text-generation" + assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings" + assert ModelType.RERANK.to_origin_model_type() == "reranking" + assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text" + assert ModelType.TTS.to_origin_model_type() == "tts" + assert ModelType.MODERATION.to_origin_model_type() == "moderation" + + # Testing the else branch in to_origin_model_type + # Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it. + # But if we look at the implementation: + # if self == self.LLM: ... elif ... else: raise ValueError + # We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise. + # Actually, adding a new member to an enum at runtime is possible but messy. + # Let's see if we can trigger it. + + +class TestFetchFrom: + def test_values(self): + assert FetchFrom.PREDEFINED_MODEL == "predefined-model" + assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model" + + +class TestModelFeature: + def test_values(self): + assert ModelFeature.TOOL_CALL == "tool-call" + assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call" + assert ModelFeature.AGENT_THOUGHT == "agent-thought" + assert ModelFeature.VISION == "vision" + assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call" + assert ModelFeature.DOCUMENT == "document" + assert ModelFeature.VIDEO == "video" + assert ModelFeature.AUDIO == "audio" + assert ModelFeature.STRUCTURED_OUTPUT == "structured-output" + + +class TestDefaultParameterName: + def test_value_of(self): + assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE + assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P + + with pytest.raises(ValueError, match="invalid parameter name invalid"): + DefaultParameterName.value_of("invalid") + + +class TestParameterType: + def test_values(self): + assert ParameterType.FLOAT == "float" + assert ParameterType.INT == "int" + assert ParameterType.STRING == "string" + assert ParameterType.BOOLEAN == "boolean" + assert ParameterType.TEXT == "text" + + +class TestModelPropertyKey: + def test_values(self): + assert ModelPropertyKey.MODE == "mode" + assert ModelPropertyKey.CONTEXT_SIZE == "context_size" + + +class TestProviderModel: + def test_provider_model(self): + model = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model.model == "gpt-4" + assert model.support_structure_output is False + + model_with_features = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.STRUCTURED_OUTPUT], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model_with_features.support_structure_output is True + + +class TestParameterRule: + def test_parameter_rule(self): + rule = ParameterRule( + name="temperature", + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=0.7, + min=0.0, + max=1.0, + precision=2, + ) + assert rule.name == "temperature" + assert rule.default == 0.7 + + +class TestPriceConfig: + def test_price_config(self): + config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD") + assert config.input == Decimal("0.01") + assert config.output == Decimal("0.02") + + +class TestAIModelEntity: + def test_ai_model_entity_no_json_schema(self): + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or []) + + def test_ai_model_entity_with_json_schema(self): + # Case: json_schema in parameter rules, features is None + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_features_empty(self): + # Case: json_schema in parameter rules, features is empty list + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_other_features(self): + # Case: json_schema in parameter rules, features has other things + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.VISION], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + assert ModelFeature.VISION in entity.features + + +class TestModelUsage: + def test_model_usage(self): + usage = ModelUsage() + assert isinstance(usage, ModelUsage) + + +class TestPriceType: + def test_values(self): + assert PriceType.INPUT == "input" + assert PriceType.OUTPUT == "output" + + +class TestPriceInfo: + def test_price_info(self): + info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD") + assert info.total_amount == Decimal("0.05") diff --git a/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py new file mode 100644 index 0000000000..af62b2a84c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py @@ -0,0 +1,63 @@ +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class TestInvokeErrors: + def test_invoke_error_with_description(self): + error = InvokeError("Custom description") + assert error.description == "Custom description" + assert str(error) == "Custom description" + assert isinstance(error, ValueError) + + def test_invoke_error_without_description(self): + error = InvokeError() + assert error.description is None + assert str(error) == "InvokeError" + + def test_invoke_connection_error(self): + # Now preserves class-level description + error = InvokeConnectionError() + assert error.description == "Connection Error" + assert str(error) == "Connection Error" + assert isinstance(error, InvokeError) + + # Test with explicit description + error_with_desc = InvokeConnectionError("Connection Error") + assert error_with_desc.description == "Connection Error" + assert str(error_with_desc) == "Connection Error" + + def test_invoke_server_unavailable_error(self): + error = InvokeServerUnavailableError() + assert error.description == "Server Unavailable Error" + assert str(error) == "Server Unavailable Error" + assert isinstance(error, InvokeError) + + def test_invoke_rate_limit_error(self): + error = InvokeRateLimitError() + assert error.description == "Rate Limit Error" + assert str(error) == "Rate Limit Error" + assert isinstance(error, InvokeError) + + def test_invoke_authorization_error(self): + error = InvokeAuthorizationError() + assert error.description == "Incorrect model credentials provided, please check and try again. " + assert str(error) == "Incorrect model credentials provided, please check and try again. " + assert isinstance(error, InvokeError) + + def test_invoke_bad_request_error(self): + error = InvokeBadRequestError() + assert error.description == "Bad Request Error" + assert str(error) == "Bad Request Error" + assert isinstance(error, InvokeError) + + def test_invoke_error_inheritance(self): + # Test that we can override the default description in subclasses + error = InvokeBadRequestError("Overridden Error") + assert error.description == "Overridden Error" + assert str(error) == "Overridden Error" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py new file mode 100644 index 0000000000..382dce876e --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py @@ -0,0 +1,336 @@ +import decimal +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, + PriceType, +) +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel + + +class TestAIModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def ai_model(self, mock_plugin_model_provider): + return AIModel( + tenant_id="tenant_123", + model_type=ModelType.LLM, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_invoke_error_mapping(self, ai_model): + mapping = ai_model._invoke_error_mapping + assert InvokeConnectionError in mapping + assert InvokeServerUnavailableError in mapping + assert InvokeRateLimitError in mapping + assert InvokeAuthorizationError in mapping + assert InvokeBadRequestError in mapping + assert PluginDaemonInnerError in mapping + assert ValueError in mapping + + def test_transform_invoke_error(self, ai_model): + # Case: mapped error (InvokeAuthorizationError) + err = Exception("Original error") + with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeAuthorizationError) + assert "Incorrect model credentials provided" in str(transformed.description) + + # Case: mapped error (InvokeError subclass) + with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeError) + assert "[test_provider]" in transformed.description + + # Case: mapped error (not InvokeError) + class CustomNonInvokeError(Exception): + pass + + with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert transformed == err + + # Case: unmapped error + unmapped_err = Exception("Unmapped") + transformed = ai_model._transform_invoke_error(unmapped_err) + assert isinstance(transformed, InvokeError) + assert "Error: Unmapped" in transformed.description + + def test_get_price(self, ai_model): + model_name = "test_model" + credentials = {"key": "value"} + + # Mock get_model_schema + mock_schema = MagicMock(spec=AIModelEntity) + mock_schema.pricing = PriceConfig( + input=decimal.Decimal("0.002"), + output=decimal.Decimal("0.004"), + unit=decimal.Decimal(1000), # 1000 tokens per unit + currency="USD", + ) + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + # Test INPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.002") + + # Test OUTPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.004") + + # Case: unit_price is None (returns zeroed PriceInfo) + mock_schema.pricing = None + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) + assert price_info.total_amount == decimal.Decimal("0.0") + + def test_get_price_no_price_config_error(self, ai_model): + model_name = "test_model" + + # We need it to be truthy at line 107 and 112 but falsy at line 127. + class ChangingPriceConfig: + def __init__(self): + self.input = decimal.Decimal("0.01") + self.unit = decimal.Decimal(1) + self.currency = "USD" + self.called = 0 + + def __bool__(self): + self.called += 1 + return self.called <= 2 + + mock_schema = MagicMock() + mock_schema.pricing = ChangingPriceConfig() + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + with pytest.raises(ValueError) as excinfo: + ai_model.get_price(model_name, {}, PriceType.INPUT, 1000) + assert "Price config not found" in str(excinfo.value) + + def test_get_model_schema_cache_hit(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis: + mock_redis.get.return_value = mock_schema.model_dump_json().encode() + + schema = ai_model.get_model_schema(model_name, credentials) + + assert schema.model == "test_model" + mock_redis.get.assert_called_once() + + def test_get_model_schema_cache_miss(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = None + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = mock_schema + + schema = ai_model.get_model_schema(model_name, credentials) + + assert schema == mock_schema + mock_manager.get_model_schema.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_get_model_schema_redis_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.side_effect = RedisError("Connection refused") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_manager.get_model_schema.assert_called_once() + + def test_get_model_schema_validation_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = b"invalid json" + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + # This should trigger ValidationError at line 166 and go to delete() + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_redis.delete.assert_called() + + def test_get_model_schema_redis_delete_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = b'{"invalid": "schema"}' + mock_redis.delete.side_effect = RedisError("Delete failed") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_redis.delete.assert_called() + + def test_get_model_schema_redis_setex_error(self, ai_model): + model_name = "test_model" + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = None + mock_redis.setex.side_effect = RuntimeError("Setex failed") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = mock_schema + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema == mock_schema + mock_redis.setex.assert_called() + + def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model): + model_name = "test_model" + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[ + ParameterRule( + name="invalid", + use_template="invalid_template_name", + label=I18nObject(en_US="Invalid"), + type=ParameterType.FLOAT, + ) + ], + ) + + with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): + schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + assert schema.parameter_rules[0].use_template == "invalid_template_name" + + def test_get_customizable_model_schema_from_credentials(self, ai_model): + model_name = "test_model" + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[ + ParameterRule( + name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + help=I18nObject(en_US=""), + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + help=I18nObject(en_US="", zh_Hans=""), + ), + ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING), + ], + ) + + with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): + schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + + assert schema.parameter_rules[0].max == 1.0 + assert schema.parameter_rules[1].help.en_US != "" + assert schema.parameter_rules[2].help.zh_Hans != "" + assert schema.parameter_rules[3].use_template is None + + def test_get_customizable_model_schema_from_credentials_none(self, ai_model): + with patch.object(AIModel, "get_customizable_model_schema", return_value=None): + schema = ai_model.get_customizable_model_schema_from_credentials("model", {}) + assert schema is None + + def test_get_customizable_model_schema_default(self, ai_model): + assert ai_model.get_customizable_model_schema("model", {}) is None + + def test_get_default_parameter_rule_variable_map(self, ai_model): + # Valid + res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) + assert res["default"] == 0.0 + + # Invalid + with pytest.raises(Exception) as excinfo: + ai_model._get_default_parameter_rule_variable_map("invalid_name") + assert "Invalid model parameter rule name" in str(excinfo.value) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py new file mode 100644 index 0000000000..a692f8023a --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py @@ -0,0 +1,476 @@ +import logging +from collections.abc import Generator, Iterator, Sequence +from dataclasses import dataclass, field +from datetime import datetime +from decimal import Decimal +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module + +# Access large_language_model members via llm_module to avoid partial import issues in CI +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelType, PriceInfo +from dify_graph.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks + + +def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage: + return LLMUsage( + prompt_tokens=prompt_tokens, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1), + prompt_price=Decimal(prompt_tokens) * Decimal("0.001"), + completion_tokens=completion_tokens, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1), + completion_price=Decimal(completion_tokens) * Decimal("0.002"), + total_tokens=prompt_tokens + completion_tokens, + total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"), + currency="USD", + latency=0.0, + ) + + +def _tool_call_delta( + *, + tool_call_id: str, + tool_type: str = "function", + function_name: str = "", + function_arguments: str = "", +) -> AssistantPromptMessage.ToolCall: + return AssistantPromptMessage.ToolCall( + id=tool_call_id, + type=tool_type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments), + ) + + +def _chunk( + *, + model: str = "test-model", + content: str | list[Any] | None = None, + tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, + usage: LLMUsage | None = None, + system_fingerprint: str | None = None, +) -> LLMResultChunk: + return LLMResultChunk( + model=model, + system_fingerprint=system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []), + usage=usage, + ), + ) + + +@dataclass +class SpyCallback(Callback): + raise_error: bool = False + before: list[dict[str, Any]] = field(default_factory=list) + new_chunk: list[dict[str, Any]] = field(default_factory=list) + after: list[dict[str, Any]] = field(default_factory=list) + error: list[dict[str, Any]] = field(default_factory=list) + + def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override] + self.before.append(kwargs) + + def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override] + self.new_chunk.append(kwargs) + + def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override] + self.after.append(kwargs) + + def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override] + self.error.append(kwargs) + + +class _TestLLM(llm_module.LargeLanguageModel): + def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override] + return PriceInfo( + unit_price=Decimal("0.01"), + unit=Decimal(1), + total_amount=Decimal(tokens) * Decimal("0.01"), + currency="USD", + ) + + def _transform_invoke_error(self, error: Exception) -> Exception: # type: ignore[override] + return RuntimeError(f"transformed: {error}") + + +@pytest.fixture +def llm() -> _TestLLM: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + return _TestLLM.model_construct( + tenant_id="tenant", + model_type=ModelType.LLM, + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + started_at=1.0, + ) + + +def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123")) + assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123" + + +def test_run_callbacks_no_callbacks_noop() -> None: + invoked: list[int] = [] + llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1)) + llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1)) + assert invoked == [] + + +def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None: + class Boom: + raise_error = False + + caplog.set_level(logging.WARNING) + llm_module._run_callbacks( + [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) + ) + assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records) + + +def test_run_callbacks_reraises_when_raise_error_true() -> None: + class Boom: + raise_error = True + + with pytest.raises(ValueError, match="boom"): + llm_module._run_callbacks( + [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) + ) + + +def test_get_or_create_tool_call_empty_id_returns_last() -> None: + calls = [ + _tool_call_delta(tool_call_id="id1", function_name="a"), + _tool_call_delta(tool_call_id="id2", function_name="b"), + ] + assert llm_module._get_or_create_tool_call(calls, "") is calls[-1] + + +def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None: + with pytest.raises(ValueError, match="tool_call_id is empty"): + llm_module._get_or_create_tool_call([], "") + + +def test_get_or_create_tool_call_creates_if_missing() -> None: + calls: list[AssistantPromptMessage.ToolCall] = [] + tool_call = llm_module._get_or_create_tool_call(calls, "new-id") + assert tool_call.id == "new-id" + assert tool_call.function.name == "" + assert tool_call.function.arguments == "" + assert calls == [tool_call] + + +def test_get_or_create_tool_call_returns_existing_when_found() -> None: + existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}") + calls = [existing] + assert llm_module._get_or_create_tool_call(calls, "same-id") is existing + + +def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None: + tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{") + delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}") + llm_module._merge_tool_call_delta(tool_call, delta) + assert tool_call.id == "id2" + assert tool_call.type == "function" + assert tool_call.function.name == "y" + assert tool_call.function.arguments == "{}" + + +def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed")) + delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{") + existing: list[AssistantPromptMessage.ToolCall] = [] + llm_module._increase_tool_call([delta], existing) + assert len(existing) == 1 + assert existing[0].id == "chatcmpl-tool-fixed" + assert existing[0].function.name == "fn" + assert existing[0].function.arguments == "{" + + +def test_increase_tool_call_merges_incremental_arguments() -> None: + existing: list[AssistantPromptMessage.ToolCall] = [] + llm_module._increase_tool_call( + [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing + ) + llm_module._increase_tool_call( + [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing + ) + assert len(existing) == 1 + assert existing[0].function.name == "fn" + assert existing[0].function.arguments == "{}" + + +@pytest.mark.parametrize( + ("content", "expected_type"), + [ + ("hello", str), + ([TextPromptMessageContent(data="hello")], list), + ], +) +def test_build_llm_result_from_chunks_accumulates_and_raises_error( + content: str | list[TextPromptMessageContent], + expected_type: type, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain")) + caplog.set_level(logging.DEBUG) + + tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}") + first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1") + + def iter_with_error() -> Iterator[LLMResultChunk]: + yield first + raise RuntimeError("drain boom") + + with pytest.raises(RuntimeError, match="drain boom"): + _build_llm_result_from_chunks( + model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error() + ) + + assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records) + + +def test_build_llm_result_from_chunks_empty_iterator() -> None: + def empty() -> Iterator[LLMResultChunk]: + if False: # pragma: no cover + yield _chunk() + return + + result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty()) + assert result.message.content == [] + assert result.usage.total_tokens == 0 + assert result.system_fingerprint is None + + +def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None: + chunks = iter([_chunk(content="first"), _chunk(content="second")]) + result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks) + assert result.message.content == "firstsecond" + + +def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None: + invoked: dict[str, Any] = {} + + class FakePluginModelClient: + def invoke_llm(self, **kwargs: Any) -> str: + invoked.update(kwargs) + return "ok" + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + + prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),) + result = llm_module._invoke_llm_via_plugin( + tenant_id="t", + user_id="u", + plugin_id="p", + provider="prov", + model="m", + credentials={"k": "v"}, + model_parameters={"temp": 1}, + prompt_messages=prompt_messages, + tools=None, + stop=("a", "b"), + stream=True, + ) + + assert result == "ok" + assert invoked["prompt_messages"] == list(prompt_messages) + assert invoked["stop"] == ["a", "b"] + + +def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None: + llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) + assert ( + llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result + ) + + +def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None: + chunks = iter([_chunk(content="hello", usage=_usage(1, 1))]) + result = llm_module._normalize_non_stream_plugin_result( + model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks + ) + assert isinstance(result, LLMResult) + assert result.message.content == "hello" + + +def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: plugin_result, + ) + cb = SpyCallback() + prompt_messages = [UserPromptMessage(content="hi")] + result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb]) + assert isinstance(result, LLMResult) + assert result.prompt_messages == prompt_messages + assert len(cb.before) == 1 + assert len(cb.after) == 1 + assert cb.after[0]["result"].prompt_messages == prompt_messages + + +def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + plugin_chunks = iter( + [ + _chunk(model="m1", content="a"), + _chunk( + model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp" + ), + _chunk(model="m3", content=None), + ] + ) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: plugin_chunks, + ) + + cb = SpyCallback() + prompt_messages = [UserPromptMessage(content="hi")] + gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb]) + + assert isinstance(gen, Generator) + chunks = list(gen) + assert len(chunks) == 3 + assert all(chunk.prompt_messages == prompt_messages for chunk in chunks) + assert len(cb.before) == 1 + assert len(cb.new_chunk) == 3 + assert len(cb.after) == 1 + final_result: LLMResult = cb.after[0]["result"] + assert final_result.model == "m3" + assert final_result.system_fingerprint == "fp" + assert isinstance(final_result.message.content, list) + assert [c.data for c in final_result.message.content] == ["a", "b"] + assert final_result.usage.total_tokens == 5 + + +def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + def boom(**_: Any) -> Any: + raise ValueError("plugin down") + + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom + ) + cb = SpyCallback() + with pytest.raises(RuntimeError, match="transformed: plugin down"): + llm.invoke( + model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb] + ) + assert len(cb.error) == 1 + assert isinstance(cb.error[0]["ex"], ValueError) + + +def test_invoke_raises_not_implemented_for_unsupported_result_type( + llm: _TestLLM, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result") + monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result") + with pytest.raises(NotImplementedError, match="unsupported invoke result type"): + llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) + + +def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + captured_callbacks: list[list[Callback]] = [] + + class FakeLoggingCallback(SpyCallback): + pass + + monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback) + monkeypatch.setattr(llm_module.dify_config, "DEBUG", True) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()), + ) + + original_trigger = llm._trigger_before_invoke_callbacks + + def spy_trigger(*args: Any, **kwargs: Any) -> None: + captured_callbacks.append(list(kwargs["callbacks"])) + original_trigger(*args, **kwargs) + + monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger) + llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) + assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0]) + + +def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0 + + +def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True) + + class FakePluginModelClient: + def get_llm_num_tokens(self, **kwargs: Any) -> int: + assert kwargs["tenant_id"] == "tenant" + assert kwargs["plugin_id"] == "plugin-id" + assert kwargs["provider"] == "provider" + assert kwargs["model_type"] == "llm" + return 42 + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42 + + +def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5) + llm.started_at = 1.0 + usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5) + assert usage.total_tokens == 15 + assert usage.total_price == Decimal("0.15") + assert usage.latency == 3.5 + + +def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None: + def broken() -> Iterator[LLMResultChunk]: + yield _chunk(content="ok") + raise ValueError("chunk stream broken") + + gen = llm._invoke_result_generator( + model="m", + result=broken(), + credentials={}, + prompt_messages=[UserPromptMessage(content="u")], + model_parameters={}, + callbacks=[SpyCallback()], + ) + + with pytest.raises(RuntimeError, match="transformed: chunk stream broken"): + list(gen) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py new file mode 100644 index 0000000000..6ccc44ceb8 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel + + +class TestModerationModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def moderation_model(self, mock_plugin_model_provider): + return ModerationModel( + tenant_id="tenant_123", + model_type=ModelType.MODERATION, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, moderation_model): + assert moderation_model.model_type == ModelType.MODERATION + + def test_invoke_success(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + user = "user_123" + + with ( + patch("core.plugin.impl.model.PluginModelClient") as mock_client_class, + patch("time.perf_counter", return_value=1.0), + ): + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.return_value = True + + result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user) + + assert result is True + assert moderation_model.started_at == 1.0 + mock_client.invoke_moderation.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + text=text, + ) + + def test_invoke_success_no_user(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.return_value = False + + result = moderation_model.invoke(model=model_name, credentials=credentials, text=text) + + assert result is False + mock_client.invoke_moderation.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + text=text, + ) + + def test_invoke_exception(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + moderation_model.invoke(model=model_name, credentials=credentials, text=text) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py new file mode 100644 index 0000000000..67828894b3 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py @@ -0,0 +1,181 @@ +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel + + +@pytest.fixture +def rerank_model() -> RerankModel: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + return RerankModel.model_construct( + tenant_id="tenant", + model_type=ModelType.RERANK, + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + ) + + +def test_model_type_is_rerank_by_default() -> None: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + model = RerankModel( + tenant_id="tenant", + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + ) + assert model.model_type == ModelType.RERANK + + +def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: + expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)]) + + class FakePluginModelClient: + def __init__(self) -> None: + self.invoke_rerank_called_with: dict[str, Any] | None = None + + def invoke_rerank(self, **kwargs: Any) -> RerankResult: + self.invoke_rerank_called_with = kwargs + return expected + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + result = rerank_model.invoke( + model="rerank", + credentials={"k": "v"}, + query="q", + docs=["d1", "d2"], + score_threshold=0.2, + top_n=10, + user="user-1", + ) + + assert result == expected + assert fake_client.invoke_rerank_called_with == { + "tenant_id": "tenant", + "user_id": "user-1", + "plugin_id": "plugin-id", + "provider": "provider", + "model": "rerank", + "credentials": {"k": "v"}, + "query": "q", + "docs": ["d1", "d2"], + "score_threshold": 0.2, + "top_n": 10, + } + + +def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: + class FakePluginModelClient: + def __init__(self) -> None: + self.kwargs: dict[str, Any] | None = None + + def invoke_rerank(self, **kwargs: Any) -> RerankResult: + self.kwargs = kwargs + return RerankResult(model="m", docs=[]) + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) + assert fake_client.kwargs is not None + assert fake_client.kwargs["user_id"] == "unknown" + + +def test_invoke_transforms_and_raises_on_plugin_error( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + class FakePluginModelClient: + def invoke_rerank(self, **_: Any) -> RerankResult: + raise ValueError("plugin down") + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) + + with pytest.raises(RuntimeError, match="transformed: plugin down"): + rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) + + +def test_invoke_multimodal_calls_plugin_and_passes_args( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)]) + + class FakePluginModelClient: + def __init__(self) -> None: + self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None + + def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult: + self.invoke_multimodal_rerank_called_with = kwargs + return expected + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + query = {"type": "text", "text": "q"} + docs = [{"type": "text", "text": "d1"}] + result = rerank_model.invoke_multimodal_rerank( + model="mm", + credentials={"k": "v"}, + query=query, + docs=docs, + score_threshold=None, + top_n=None, + user=None, + ) + + assert result == expected + assert fake_client.invoke_multimodal_rerank_called_with is not None + assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant" + assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown" + assert fake_client.invoke_multimodal_rerank_called_with["query"] == query + assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs + + +def test_invoke_multimodal_transforms_and_raises_on_plugin_error( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + class FakePluginModelClient: + def invoke_multimodal_rerank(self, **_: Any) -> RerankResult: + raise ValueError("plugin down") + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) + + with pytest.raises(RuntimeError, match="transformed: plugin down"): + rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}]) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py new file mode 100644 index 0000000000..f891718dc6 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py @@ -0,0 +1,87 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel + + +class TestSpeech2TextModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def speech2text_model(self, mock_plugin_model_provider): + return Speech2TextModel( + tenant_id="tenant_123", + model_type=ModelType.SPEECH2TEXT, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, speech2text_model): + assert speech2text_model.model_type == ModelType.SPEECH2TEXT + + def test_invoke_success(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + user = "user_123" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.return_value = "transcribed text" + + result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user) + + assert result == "transcribed text" + mock_client.invoke_speech_to_text.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + file=file, + ) + + def test_invoke_success_no_user(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.return_value = "transcribed text" + + result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file) + + assert result == "transcribed text" + mock_client.invoke_speech_to_text.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + file=file, + ) + + def test_invoke_exception(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + speech2text_model.invoke(model=model_name, credentials=credentials, file=file) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py new file mode 100644 index 0000000000..c8f0a2ad49 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py @@ -0,0 +1,185 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.entities.embedding_type import EmbeddingInputType +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class TestTextEmbeddingModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def text_embedding_model(self, mock_plugin_model_provider): + return TextEmbeddingModel( + tenant_id="tenant_123", + model_type=ModelType.TEXT_EMBEDDING, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, text_embedding_model): + assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING + + def test_invoke_with_texts(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello", "world"] + user = "user_123" + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.return_value = expected_result + + result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user) + + assert result == expected_result + mock_client.invoke_text_embedding.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + texts=texts, + input_type=EmbeddingInputType.DOCUMENT, + ) + + def test_invoke_with_multimodel_documents(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + multimodel_documents = [{"type": "text", "text": "hello"}] + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_multimodal_embedding.return_value = expected_result + + result = text_embedding_model.invoke( + model=model_name, credentials=credentials, multimodel_documents=multimodel_documents + ) + + assert result == expected_result + mock_client.invoke_multimodal_embedding.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + documents=multimodel_documents, + input_type=EmbeddingInputType.DOCUMENT, + ) + + def test_invoke_no_input(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + with pytest.raises(ValueError) as excinfo: + text_embedding_model.invoke(model=model_name, credentials=credentials) + + assert "No texts or files provided" in str(excinfo.value) + + def test_invoke_precedence(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello"] + multimodel_documents = [{"type": "text", "text": "world"}] + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.return_value = expected_result + + result = text_embedding_model.invoke( + model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents + ) + + assert result == expected_result + mock_client.invoke_text_embedding.assert_called_once() + mock_client.invoke_multimodal_embedding.assert_not_called() + + def test_invoke_exception(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello"] + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) + + def test_get_num_tokens(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello", "world"] + expected_tokens = [1, 1] + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.get_text_embedding_num_tokens.return_value = expected_tokens + + result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts) + + assert result == expected_tokens + mock_client.get_text_embedding_num_tokens.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + texts=texts, + ) + + def test_get_context_size(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + # Test case 1: Context size in schema + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} + + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size(model_name, credentials) == 2048 + + # Test case 2: No schema + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_context_size(model_name, credentials) == 1000 + + # Test case 3: Context size NOT in schema properties + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size(model_name, credentials) == 1000 + + def test_get_max_chunks(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + # Test case 1: Max chunks in schema + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} + + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 10 + + # Test case 2: No schema + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 + + # Test case 3: Max chunks NOT in schema properties + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py new file mode 100644 index 0000000000..b1aca9baa3 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py @@ -0,0 +1,131 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel + + +class TestTTSModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def tts_model(self, mock_plugin_model_provider): + return TTSModel( + tenant_id="tenant_123", + model_type=ModelType.TTS, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, tts_model): + assert tts_model.model_type == ModelType.TTS + + def test_invoke_success(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + user = "user_123" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.return_value = [b"audio_chunk"] + + result = tts_model.invoke( + model=model_name, + tenant_id=tenant_id, + credentials=credentials, + content_text=content_text, + voice=voice, + user=user, + ) + + assert list(result) == [b"audio_chunk"] + mock_client.invoke_tts.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def test_invoke_success_no_user(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.return_value = [b"audio_chunk"] + + result = tts_model.invoke( + model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice + ) + + assert list(result) == [b"audio_chunk"] + mock_client.invoke_tts.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def test_invoke_exception(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + tts_model.invoke( + model=model_name, + tenant_id=tenant_id, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) + + def test_get_tts_model_voices(self, tts_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + language = "en-US" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}] + + result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language) + + assert result == [{"name": "Voice1"}] + mock_client.get_tts_model_voices.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + language=language, + ) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py new file mode 100644 index 0000000000..dde6ea02b5 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock, patch + +import dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module +from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + + +class TestGPT2Tokenizer: + def setup_method(self): + # Reset the global tokenizer before each test to ensure we test initialization + gpt2_tokenizer_module._tokenizer = None + + def test_get_encoder_tiktoken(self): + """ + Test that get_encoder successfully uses tiktoken when available. + """ + mock_encoding = MagicMock() + # Mock tiktoken to be sure it's used + with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_get_encoding.assert_called_once_with("gpt2") + + # Verify singleton behavior within the same test + encoder2 = GPT2Tokenizer.get_encoder() + assert encoder2 is encoder + assert mock_get_encoding.call_count == 1 + + def test_get_encoder_tiktoken_fallback(self): + """ + Test that get_encoder falls back to transformers when tiktoken fails. + """ + # patch tiktoken.get_encoding to raise an exception + with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): + # patch transformers.GPT2Tokenizer + with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: + mock_transformer_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_transformer_tokenizer + + with patch( + "dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" + ) as mock_logger: + encoder = GPT2Tokenizer.get_encoder() + + assert encoder == mock_transformer_tokenizer + mock_from_pretrained.assert_called_once() + mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") + + def test_get_num_tokens(self): + """ + Test get_num_tokens returns the correct count. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2, 3, 4, 5] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer.get_num_tokens("test text") + assert tokens_count == 5 + mock_encoder.encode.assert_called_once_with("test text") + + def test_get_num_tokens_by_gpt2_direct(self): + """ + Test _get_num_tokens_by_gpt2 directly. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") + assert tokens_count == 2 + mock_encoder.encode.assert_called_once_with("hello") + + def test_get_encoder_already_initialized(self): + """ + Test that if _tokenizer is already set, it returns it immediately. + """ + mock_existing_tokenizer = MagicMock() + gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer + + # Tiktoken should not be called if already initialized + with patch("tiktoken.get_encoding") as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_existing_tokenizer + mock_get_encoding.assert_not_called() + + def test_get_encoder_thread_safety(self): + """ + Simple test to ensure the lock is used. + """ + mock_encoding = MagicMock() + with patch("tiktoken.get_encoding", return_value=mock_encoding): + # We patch the lock in the module + with patch("dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py new file mode 100644 index 0000000000..1ad0210375 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py @@ -0,0 +1,522 @@ +import logging +from datetime import datetime +from threading import Lock +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +import contexts +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, +) +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _provider_entity( + *, + provider: str, + supported_model_types: list[ModelType] | None = None, + models: list[AIModelEntity] | None = None, + icon_small: I18nObject | None = None, + icon_small_dark: I18nObject | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider), + supported_model_types=supported_model_types or [ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + icon_small=icon_small, + icon_small_dark=icon_small_dark, + ) + + +def _plugin_provider( + *, plugin_id: str, declaration: ProviderEntity, provider: str = "provider" +) -> PluginModelProviderEntity: + return PluginModelProviderEntity.model_construct( + id=f"{plugin_id}-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider=provider, + tenant_id="tenant", + plugin_unique_identifier=f"{plugin_id}-uid", + plugin_id=plugin_id, + declaration=declaration, + ) + + +@pytest.fixture(autouse=True) +def _reset_plugin_model_provider_context() -> None: + contexts.plugin_model_providers_lock.set(Lock()) + contexts.plugin_model_providers.set(None) + + +@pytest.fixture +def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + manager = MagicMock() + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager) + return manager + + +@pytest.fixture +def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory: + return ModelProviderFactory(tenant_id="tenant") + + +def test_get_plugin_model_providers_initializes_context_on_lookup_error( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + original_get = contexts.plugin_model_providers.get + calls = {"n": 0} + + def flaky_get() -> Any: + calls["n"] += 1 + if calls["n"] == 1: + raise LookupError + return original_get() + + monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get) + + providers = factory.get_plugin_model_providers() + assert len(providers) == 1 + assert providers[0].declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_providers_caches_and_does_not_refetch( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + first = factory.get_plugin_model_providers() + second = factory.get_plugin_model_providers() + + assert first is second + fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant") + + +def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + d1 = _provider_entity(provider="openai") + d2 = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=d1), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=d2), + ] + + providers = factory.get_providers() + assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"] + + +def test_get_plugin_model_provider_converts_short_provider_id( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + provider = factory.get_plugin_model_provider("openai") + assert provider.declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_provider_raises_on_invalid_provider( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + with pytest.raises(ValueError, match="Invalid provider"): + factory.get_plugin_model_provider("langgenius/unknown/unknown") + + +def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + schema = factory.get_provider_schema("openai") + assert schema.provider == "langgenius/openai/openai" + + +def test_provider_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.provider_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"x": "y"}) + + +def test_provider_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.provider_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator", + lambda _: fake_validator, + ) + + filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True}) + assert filtered == {"filtered": True} + fake_plugin_manager.validate_provider_credentials.assert_called_once() + kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["credentials"] == {"filtered": True} + + +def test_model_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"} + ) + + +def test_model_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator", + lambda *_: fake_validator, + ) + + filtered = factory.model_credentials_validate( + provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True} + ) + assert filtered == {"filtered": True} + kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["model_type"] == "text-embedding" + assert kwargs["model"] == "m" + assert kwargs["credentials"] == {"filtered": True} + + +def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None: + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = model_schema.model_dump_json().encode() + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"}) + == model_schema + ) + + +def test_get_model_schema_cache_invalid_json_deletes_key( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert mock_redis.delete.called + assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_delete_redis_error_is_logged( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + mock_redis.delete.side_effect = RedisError("nope") + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_redis_get_error_falls_back_to_plugin( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.plugin_model_manager.get_model_schema.return_value = None + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.side_effect = RedisError("down") + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + factory.plugin_model_manager.get_model_schema.return_value = model_schema + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_redis.setex.side_effect = RedisError("nope") + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + == model_schema + ) + assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records) + + +@pytest.mark.parametrize( + ("model_type", "expected_class"), + [ + (ModelType.LLM, "LargeLanguageModel"), + (ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"), + (ModelType.RERANK, "RerankModel"), + (ModelType.SPEECH2TEXT, "Speech2TextModel"), + (ModelType.MODERATION, "ModerationModel"), + (ModelType.TTS, "TTSModel"), + ], +) +def test_get_model_type_instance_dispatches_by_type( + factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + sentinel = object() + monkeypatch.setattr( + f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}", + MagicMock(model_validate=lambda _: sentinel), + ) + + assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel + + +def test_get_model_type_instance_raises_on_unsupported( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + class UnknownModelType: + pass + + with pytest.raises(ValueError, match="Unsupported model type"): + factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type] + + +def test_get_models_filters_by_provider_and_model_type( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + llm = AIModelEntity( + model="m1", + label=I18nObject(en_US="m1"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + embed = AIModelEntity( + model="e1", + label=I18nObject(en_US="e1"), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + openai = _provider_entity( + provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed] + ) + anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm]) + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + # ModelType filter picks only matching models + providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert [m.model for m in providers[0].models] == ["e1"] + + # Provider filter excludes others + providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/anthropic/anthropic" + + +def test_get_models_provider_filter_skips_non_matching( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + openai = _provider_entity(provider="openai") + anthropic = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM) + assert providers == [] + + +def test_get_provider_icon_fetches_asset_and_returns_mime_type( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + assert tenant_id == "tenant" + return f"bytes:{id}".encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, mime = factory.get_provider_icon("openai", "icon_small", "en_US") + assert data == b"bytes:icon.png" + assert mime == "image/png" + + data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans") + assert data == b"bytes:dark-zh.svg" + assert mime == "image/svg+xml" + + +def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + return id.encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans") + assert data == b"icon-zh.png" + + data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US") + assert data == b"dark-en.svg" + + +def test_get_provider_icon_raises_for_missing_icons( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity(provider="langgenius/openai/openai") + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + with pytest.raises(ValueError, match="does not have small icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + with pytest.raises(ValueError, match="does not have small dark icon"): + factory.get_provider_icon("openai", "icon_small_dark", "en_US") + + +def test_get_provider_icon_raises_for_unsupported_icon_type( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="Unsupported icon type"): + factory.get_provider_icon("openai", "nope", "en_US") + + +def test_get_provider_icon_raises_when_file_name_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="does not have icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + +def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case( + factory: ModelProviderFactory, +) -> None: + plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google") + assert plugin_id == "langgenius/gemini" + assert provider_name == "google" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py new file mode 100644 index 0000000000..6d52457c8c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py @@ -0,0 +1,201 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormOption, + FormShowOnObject, + FormType, +) +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator + + +class TestCommonValidator: + def test_validate_credential_form_schema_required_missing(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + with pytest.raises(ValueError, match="Variable api_key is required"): + validator._validate_credential_form_schema(schema, {}) + + def test_validate_credential_form_schema_not_required_missing_with_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + required=False, + default="default_value", + ) + assert validator._validate_credential_form_schema(schema, {}) == "default_value" + + def test_validate_credential_form_schema_not_required_missing_no_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False + ) + assert validator._validate_credential_form_schema(schema, {}) is None + + def test_validate_credential_form_schema_max_length_exceeded(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 + ) + with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): + validator._validate_credential_form_schema(schema, {"api_key": "123456"}) + + def test_validate_credential_form_schema_not_string(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) + with pytest.raises(ValueError, match="Variable api_key should be string"): + validator._validate_credential_form_schema(schema, {"api_key": 123}) + + def test_validate_credential_form_schema_select_invalid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + with pytest.raises(ValueError, match="Variable mode is not in options"): + validator._validate_credential_form_schema(schema, {"mode": "medium"}) + + def test_validate_credential_form_schema_select_valid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" + + def test_validate_credential_form_schema_switch_invalid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) + + def test_validate_credential_form_schema_switch_valid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True + assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False + + def test_validate_and_filter_credential_form_schemas_with_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="auth_type", + label=I18nObject(en_US="Auth Type"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="API Key"), value="api_key"), + FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), + ], + ), + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ), + CredentialFormSchema( + variable="client_id", + label=I18nObject(en_US="Client ID"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="oauth")], + ), + ] + + # Case 1: auth_type = api_key + credentials = {"auth_type": "api_key", "api_key": "my_secret"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "auth_type" in result + assert "api_key" in result + assert "client_id" not in result + assert result["api_key"] == "my_secret" + + # Case 2: auth_type = oauth + credentials = {"auth_type": "oauth", "client_id": "my_client"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + # Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation. + # Since 'oauth' is not an empty string, it is in result. + assert "auth_type" in result + assert "api_key" not in result + assert "client_id" in result + assert result["client_id"] == "my_client" + + def test_validate_and_filter_show_on_missing_variable(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is missing in credentials, so api_key should be filtered out + result = validator._validate_and_filter_credential_form_schemas(schemas, {}) + assert result == {} + + def test_validate_and_filter_show_on_mismatch_value(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is oauth, which doesn't match show_on + result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) + assert result == {} + + def test_validate_and_filter_multiple_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="target", + label=I18nObject(en_US="Target"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], + ) + ] + # Both match + assert "target" in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "b", "target": "val"} + ) + # One mismatch + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "c", "target": "val"} + ) + # One missing + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "target": "val"} + ) + + def test_validate_and_filter_skips_falsy_results(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), + CredentialFormSchema( + variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False + ), + ] + # Result of false switch is False. if result: is false. Not added. + # Result of empty string is "", if result: is false. Not added. + credentials = {"enabled": "false", "empty_str": ""} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "enabled" not in result + assert "empty_str" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py new file mode 100644 index 0000000000..bab2805276 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py @@ -0,0 +1,233 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormOption, + FormShowOnObject, + FormType, + ModelCredentialSchema, +) +from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator + + +def test_validate_and_filter_with_none_schema(): + validator = ModelCredentialSchemaValidator(ModelType.LLM, None) + with pytest.raises(ValueError, match="Model credential schema is None"): + validator.validate_and_filter({}) + + +def test_validate_and_filter_success(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ), + CredentialFormSchema( + variable="optional_field", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + default="default_val", + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + credentials = {"api_key": "sk-123456"} + result = validator.validate_and_filter(credentials) + + assert result["api_key"] == "sk-123456" + assert result["optional_field"] == "default_val" + assert credentials["__model_type"] == ModelType.LLM.value + + +def test_validate_and_filter_with_show_on(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=True, + show_on=[FormShowOnObject(variable="mode", value="advanced")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # mode is 'simple', conditional_field should be filtered out + credentials = {"mode": "simple", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert "conditional_field" not in result + assert result["mode"] == "simple" + + # mode is 'advanced', conditional_field should be kept + credentials = {"mode": "advanced", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert result["conditional_field"] == "secret" + assert result["mode"] == "advanced" + + # show_on variable missing in credentials + credentials = {"conditional_field": "secret"} # mode missing + with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema + validator.validate_and_filter(credentials) + + +def test_validate_and_filter_show_on_missing_trigger_var(): + # specifically test all_show_on_match = False when variable not in credentials + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional_trigger", + label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"), + type=FormType.TEXT_INPUT, + required=False, + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=False, + show_on=[FormShowOnObject(variable="optional_trigger", value="active")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # optional_trigger missing, conditional_field should be skipped + result = validator.validate_and_filter({"conditional_field": "val"}) + assert "conditional_field" not in result + + +def test_common_validator_logic_required(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({"api_key": ""}) + + +def test_common_validator_logic_max_length(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", + label=I18nObject(en_US="Key", zh_Hans="Key"), + type=FormType.TEXT_INPUT, + required=True, + max_length=5, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): + validator.validate_and_filter({"key": "123456"}) + + +def test_common_validator_logic_invalid_type(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key should be string"): + validator.validate_and_filter({"key": 123}) + + +def test_common_validator_logic_switch(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="enabled", + label=I18nObject(en_US="Enabled", zh_Hans="启用"), + type=FormType.SWITCH, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"enabled": "true"}) + assert result["enabled"] is True + + result = validator.validate_and_filter({"enabled": "false"}) + assert "enabled" not in result + + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator.validate_and_filter({"enabled": "not_a_bool"}) + + +def test_common_validator_logic_options(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="choice", + label=I18nObject(en_US="Choice", zh_Hans="选择"), + type=FormType.SELECT, + required=True, + options=[ + FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), + FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), + ], + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"choice": "a"}) + assert result["choice"] == "a" + + with pytest.raises(ValueError, match="Variable choice is not in options"): + validator.validate_and_filter({"choice": "c"}) + + +def test_validate_and_filter_optional_no_default(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({}) + assert "optional" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py new file mode 100644 index 0000000000..043306840f --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py @@ -0,0 +1,72 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema +from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) + + +class TestProviderCredentialSchemaValidator: + def test_validate_and_filter_success(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + required=False, + default="https://api.example.com", + ), + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test valid credentials + credentials = {"api_key": "my-secret-key"} + result = validator.validate_and_filter(credentials) + + assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} + + def test_validate_and_filter_missing_required(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test missing required credentials + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + def test_validate_and_filter_extra_fields_filtered(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test credentials with extra fields + credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} + result = validator.validate_and_filter(credentials) + + assert "api_key" in result + assert "extra_field" not in result + assert result == {"api_key": "my-secret-key"} + + def test_init(self): + schema = ProviderCredentialSchema(credential_form_schemas=[]) + validator = ProviderCredentialSchemaValidator(schema) + assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py new file mode 100644 index 0000000000..1ce8765a3b --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py @@ -0,0 +1,231 @@ +import dataclasses +import datetime +from collections import deque +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path, PurePath +from re import compile +from typing import Any +from unittest.mock import MagicMock +from uuid import UUID + +import pytest +from pydantic import BaseModel, ConfigDict +from pydantic.networks import AnyUrl, NameEmail +from pydantic.types import SecretBytes, SecretStr +from pydantic_core import Url +from pydantic_extra_types.color import Color + +from dify_graph.model_runtime.utils.encoders import ( + _model_dump, + decimal_encoder, + generate_encoders_by_class_tuples, + isoformat, + jsonable_encoder, +) + + +class MockEnum(Enum): + A = "a" + B = "b" + + +class MockPydanticModel(BaseModel): + model_config = ConfigDict(populate_by_name=True) + name: str + age: int + + +@dataclasses.dataclass +class MockDataclass: + name: str + value: Any + + +class MockWithDict: + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data.items()) + + +class MockWithVars: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class TestEncoders: + def test_model_dump(self): + model = MockPydanticModel(name="test", age=20) + result = _model_dump(model) + assert result == {"name": "test", "age": 20} + + def test_isoformat(self): + d = datetime.date(2023, 1, 1) + assert isoformat(d) == "2023-01-01" + t = datetime.time(12, 0, 0) + assert isoformat(t) == "12:00:00" + + def test_decimal_encoder(self): + assert decimal_encoder(Decimal("1.0")) == 1.0 + assert decimal_encoder(Decimal(1)) == 1 + assert decimal_encoder(Decimal("1.5")) == 1.5 + assert decimal_encoder(Decimal(0)) == 0 + assert decimal_encoder(Decimal(-1)) == -1 + + def test_generate_encoders_by_class_tuples(self): + type_map = {int: str, float: str, str: int} + result = generate_encoders_by_class_tuples(type_map) + assert result[str] == (int, float) + assert result[int] == (str,) + + def test_jsonable_encoder_basic_types(self): + assert jsonable_encoder("string") == "string" + assert jsonable_encoder(123) == 123 + assert jsonable_encoder(1.23) == 1.23 + assert jsonable_encoder(None) is None + + def test_jsonable_encoder_pydantic(self): + model = MockPydanticModel(name="test", age=20) + assert jsonable_encoder(model) == {"name": "test", "age": 20} + + def test_jsonable_encoder_pydantic_root(self): + # Manually create a mock that behaves like a model with __root__ + # because Pydantic v2 handles root differently, but the code checks for "__root__" + model = MagicMock(spec=BaseModel) + # _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...) + model.model_dump.return_value = {"__root__": [1, 2, 3]} + assert jsonable_encoder(model) == [1, 2, 3] + + def test_jsonable_encoder_dataclass(self): + obj = MockDataclass(name="test", value=1) + assert jsonable_encoder(obj) == {"name": "test", "value": 1} + # Test dataclass type (should not be treated as instance) + # It should fall back to vars() or dict() or at least not crash + with pytest.raises(ValueError): + jsonable_encoder(MockDataclass) + + def test_jsonable_encoder_enum(self): + assert jsonable_encoder(MockEnum.A) == "a" + + def test_jsonable_encoder_path(self): + assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" + assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" + + def test_jsonable_encoder_decimal(self): + # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") + assert jsonable_encoder(Decimal("1.23")) == "1.23" + assert jsonable_encoder(Decimal("1.000")) == "1.000" + + def test_jsonable_encoder_dict(self): + d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]} + assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + + d_with_none = {"a": 1, "b": None} + assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} + assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} + + def test_jsonable_encoder_collections(self): + assert jsonable_encoder([1, 2]) == [1, 2] + assert jsonable_encoder((1, 2)) == [1, 2] + assert jsonable_encoder({1, 2}) == [1, 2] + assert jsonable_encoder(frozenset([1, 2])) == [1, 2] + assert jsonable_encoder(deque([1, 2])) == [1, 2] + + def gen(): + yield 1 + yield 2 + + assert jsonable_encoder(gen()) == [1, 2] + + def test_jsonable_encoder_custom_encoder(self): + custom = {int: lambda x: str(x + 1)} + assert jsonable_encoder(1, custom_encoder=custom) == "2" + + # Test subclass matching for custom encoder + class SubInt(int): + pass + + assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" + + def test_jsonable_encoder_special_types(self): + # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples + assert jsonable_encoder(b"bytes") == "bytes" + assert jsonable_encoder(Color("red")) == "red" + + dt = datetime.datetime(2023, 1, 1, 12, 0, 0) + assert jsonable_encoder(dt) == dt.isoformat() + + date = datetime.date(2023, 1, 1) + assert jsonable_encoder(date) == date.isoformat() + + time = datetime.time(12, 0, 0) + assert jsonable_encoder(time) == time.isoformat() + + td = datetime.timedelta(seconds=60) + assert jsonable_encoder(td) == 60.0 + + assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" + assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" + assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" + assert jsonable_encoder(IPv6Address("::1")) == "::1" + assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" + assert jsonable_encoder(IPv6Network("::/128")) == "::/128" + + assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test " + + assert jsonable_encoder(compile("abc")) == "abc" + + # Secret types + # Check what they actually return in this environment + res_bytes = jsonable_encoder(SecretBytes(b"secret")) + assert "**********" in res_bytes + + res_str = jsonable_encoder(SecretStr("secret")) + assert res_str == "**********" + + u = UUID("12345678-1234-5678-1234-567812345678") + assert jsonable_encoder(u) == str(u) + + url = AnyUrl("https://example.com") + assert jsonable_encoder(url) == "https://example.com/" + + purl = Url("https://example.com") + assert jsonable_encoder(purl) == "https://example.com/" + + def test_jsonable_encoder_fallback(self): + # dict(obj) success + obj_dict = MockWithDict({"a": 1}) + assert jsonable_encoder(obj_dict) == {"a": 1} + + # vars(obj) success + obj_vars = MockWithVars(x=10, y=20) + assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} + + # error fallback + class ReallyUnserializable: + __slots__ = ["__weakref__"] # No __dict__ + + def __iter__(self): + raise TypeError("not iterable") + + with pytest.raises(ValueError) as exc: + jsonable_encoder(ReallyUnserializable()) + assert "not iterable" in str(exc.value) + + def test_jsonable_encoder_nested(self): + data = { + "model": MockPydanticModel(name="test", age=20), + "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], + "set": {1, 2}, + } + expected = { + "model": {"name": "test", "age": 20}, + "list": ["1.1", {"a": "/tmp"}], + "set": [1, 2], + } + assert jsonable_encoder(data) == expected