From b170eabaf3fc412e2c09481003bae188063187c2 Mon Sep 17 00:00:00 2001 From: Rajat Agarwal Date: Thu, 12 Mar 2026 09:18:13 +0530 Subject: [PATCH] test: Unit test cases for core.tools module (#32447) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: wangxiaolei Co-authored-by: akashseth-ifp Co-authored-by: mahammadasim <135003320+mahammadasim@users.noreply.github.com> --- api/core/tools/builtin_tool/provider.py | 17 +- api/core/tools/tool_file_manager.py | 1 + api/tests/unit_tests/core/tools/__init__.py | 0 .../core/tools/test_builtin_tool_base.py | 103 ++ .../core/tools/test_builtin_tool_provider.py | 216 +++++ .../core/tools/test_builtin_tools_extra.py | 310 ++++++ .../unit_tests/core/tools/test_custom_tool.py | 285 ++++++ .../core/tools/test_custom_tool_provider.py | 75 ++ .../core/tools/test_dataset_retriever_tool.py | 145 +++ .../unit_tests/core/tools/test_mcp_tool.py | 150 +++ .../core/tools/test_mcp_tool_provider.py | 73 ++ .../unit_tests/core/tools/test_plugin_tool.py | 91 ++ .../core/tools/test_plugin_tool_provider.py | 89 ++ .../unit_tests/core/tools/test_signature.py | 119 +++ .../unit_tests/core/tools/test_tool_engine.py | 280 ++++++ .../core/tools/test_tool_file_manager.py | 249 +++++ .../core/tools/test_tool_label_manager.py | 92 ++ .../core/tools/test_tool_manager.py | 899 ++++++++++++++++++ .../tools/test_tool_provider_controller.py | 110 +++ .../core/tools/utils/test_configuration.py | 148 +++ .../core/tools/utils/test_encryption.py | 43 +- .../core/tools/utils/test_misc_utils_extra.py | 478 ++++++++++ .../utils/test_model_invocation_utils.py | 158 +++ .../core/tools/utils/test_parser.py | 228 +++++ .../utils/test_system_oauth_encryption.py | 51 + .../utils/test_workflow_configuration_sync.py | 90 ++ .../tools/workflow_as_tool/test_provider.py | 196 ++++ .../core/tools/workflow_as_tool/test_tool.py | 508 ++++++---- api/tests/unit_tests/tools/__init__.py | 0 29 files changed, 5008 insertions(+), 196 deletions(-) create mode 100644 api/tests/unit_tests/core/tools/__init__.py create mode 100644 api/tests/unit_tests/core/tools/test_builtin_tool_base.py create mode 100644 api/tests/unit_tests/core/tools/test_builtin_tool_provider.py create mode 100644 api/tests/unit_tests/core/tools/test_builtin_tools_extra.py create mode 100644 api/tests/unit_tests/core/tools/test_custom_tool.py create mode 100644 api/tests/unit_tests/core/tools/test_custom_tool_provider.py create mode 100644 api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py create mode 100644 api/tests/unit_tests/core/tools/test_mcp_tool.py create mode 100644 api/tests/unit_tests/core/tools/test_mcp_tool_provider.py create mode 100644 api/tests/unit_tests/core/tools/test_plugin_tool.py create mode 100644 api/tests/unit_tests/core/tools/test_plugin_tool_provider.py create mode 100644 api/tests/unit_tests/core/tools/test_signature.py create mode 100644 api/tests/unit_tests/core/tools/test_tool_engine.py create mode 100644 api/tests/unit_tests/core/tools/test_tool_file_manager.py create mode 100644 api/tests/unit_tests/core/tools/test_tool_label_manager.py create mode 100644 api/tests/unit_tests/core/tools/test_tool_manager.py create mode 100644 api/tests/unit_tests/core/tools/test_tool_provider_controller.py create mode 100644 api/tests/unit_tests/core/tools/utils/test_configuration.py create mode 100644 api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py create mode 100644 api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py create mode 100644 api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py create mode 100644 api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py create mode 100644 api/tests/unit_tests/tools/__init__.py diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 50105bd707..20cdb3e57f 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.get_credentials_schema_by_type(CredentialType.API_KEY) - def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]: """ returns the credentials schema of the provider - :param credential_type: the type of the credential - :return: the credentials schema of the provider + :param credential_type: the type of the credential, as CredentialType or str; str values + are normalized via CredentialType.of and may raise ValueError for invalid values. + :return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an + empty list for CredentialType.UNAUTHORIZED or missing schemas. + + Reads from self.entity.oauth_schema and self.entity.credentials_schema. + Raises ValueError for invalid credential types. """ - if credential_type == CredentialType.OAUTH2.value: + if isinstance(credential_type, str): + credential_type = CredentialType.of(credential_type) + if credential_type == CredentialType.OAUTH2: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + if credential_type == CredentialType.UNAUTHORIZED: + return [] raise ValueError(f"Invalid credential type: {credential_type}") def get_oauth_client_schema(self) -> list[ProviderConfig]: diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f6eccc734b..210f488afc 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -137,6 +137,7 @@ class ToolFileManager: session.add(tool_file) session.commit() + session.refresh(tool_file) return tool_file diff --git a/api/tests/unit_tests/core/tools/__init__.py b/api/tests/unit_tests/core/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py new file mode 100644 index 0000000000..f123f60a34 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + +class _BuiltinDummyTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +def _build_tool() -> _BuiltinDummyTool: + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) + + +def test_builtin_tool_fork_and_provider_type(): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, _BuiltinDummyTool) + assert forked.runtime.tenant_id == "tenant-2" + assert tool.tool_provider_type() == ToolProviderType.BUILT_IN + + +def test_invoke_model_calls_model_invocation_utils_invoke(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: + assert ( + tool.invoke_model( + user_id="u1", + prompt_messages=[UserPromptMessage(content="hello")], + stop=[], + ) + == "result" + ) + mock_invoke.assert_called_once() + + +def test_get_max_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + assert tool.get_max_tokens() == 4096 + + +def test_get_prompt_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + +def test_runtime_none_raises(): + tool = _build_tool() + tool.runtime = None + with pytest.raises(ValueError, match="runtime is required"): + tool.get_max_tokens() + with pytest.raises(ValueError, match="runtime is required"): + tool.get_prompt_tokens([UserPromptMessage(content="hello")]) + + +def test_builtin_tool_summary_short_and_long_content_paths(): + tool = _build_tool() + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=100): + with patch.object(_BuiltinDummyTool, "get_prompt_tokens", return_value=10): + assert tool.summary(user_id="u1", content="short") == "short" + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=10): + with patch.object( + _BuiltinDummyTool, + "get_prompt_tokens", + side_effect=lambda prompt_messages: len(prompt_messages[-1].content), + ): + with patch.object( + _BuiltinDummyTool, + "invoke_model", + return_value=SimpleNamespace(message=SimpleNamespace(content="S")), + ): + result = tool.summary(user_id="u1", content="x" * 30 + "\n" + "y" * 5) + + assert result + assert "S" in result diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py new file mode 100644 index 0000000000..ad6d5906ae --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderEntity, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError + + +class _FakeBuiltinTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _ConcreteBuiltinProvider(BuiltinToolProviderController): + last_validation: tuple[str, dict[str, Any]] | None = None + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + self.last_validation = (user_id, credentials) + + +def _provider_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": ["utilities"], + }, + "credentials_for_provider": { + "api_key": { + "type": "secret-input", + "required": True, + } + }, + "oauth_schema": { + "client_schema": [ + { + "name": "client_id", + "type": "text-input", + } + ], + "credentials_schema": [ + { + "name": "access_token", + "type": "secret-input", + } + ], + }, + } + + +def _tool_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "tool_a", + "label": {"en_US": "Tool A"}, + }, + "parameters": [], + } + + +def test_builtin_tool_provider_init_load_tools_and_basic_accessors(monkeypatch): + yaml_payloads = [_provider_yaml(), _tool_yaml()] + + def _load_yaml(*args, **kwargs): + return yaml_payloads.pop(0) + + monkeypatch.setattr("core.tools.builtin_tool.provider.load_yaml_file_cached", _load_yaml) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.listdir", + lambda *args, **kwargs: ["tool_a.yaml", "__init__.py", "readme.md"], + ) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + lambda *args, **kwargs: _FakeBuiltinTool, + ) + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema() + assert provider.get_tools() + assert provider.get_tool("tool_a") is not None + assert provider.get_tool("missing") is None + assert provider.provider_type == ToolProviderType.BUILT_IN + assert provider.tool_labels == ["utilities"] + assert provider.need_credentials is True + + oauth_schema = provider.get_credentials_schema_by_type(CredentialType.OAUTH2) + assert len(oauth_schema) == 1 + api_schema = provider.get_credentials_schema_by_type(CredentialType.API_KEY) + assert len(api_schema) == 1 + assert provider.get_oauth_client_schema()[0].name == "client_id" + assert set(provider.get_supported_credential_types()) == {CredentialType.API_KEY, CredentialType.OAUTH2} + + +def test_builtin_tool_provider_invalid_credential_type_raises(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + with pytest.raises(ValueError, match="Invalid credential type: invalid"): + provider.get_credentials_schema_by_type("invalid") + + +def test_builtin_tool_provider_validate_credentials_delegates(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + provider.validate_credentials("user-1", {"api_key": "secret"}) + assert provider.last_validation == ("user-1", {"api_key": "secret"}) + + +def test_builtin_tool_provider_unauthorized_schema_is_empty(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema_by_type(CredentialType.UNAUTHORIZED) == [] + + +def test_builtin_tool_provider_init_raises_when_provider_yaml_missing(): + with patch("core.tools.builtin_tool.provider.load_yaml_file_cached", side_effect=RuntimeError("boom")): + with pytest.raises(ToolProviderNotFoundError, match="can not load provider yaml"): + _ConcreteBuiltinProvider() + + +def test_builtin_tool_provider_handles_empty_credentials_and_oauth(): + provider = object.__new__(_ConcreteBuiltinProvider) + provider.tools = [] + provider.entity = ToolProviderEntity.model_validate( + { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": None, + }, + "credentials_schema": [], + "oauth_schema": None, + }, + ) + + assert provider.get_oauth_client_schema() == [] + assert provider.get_supported_credential_types() == [] + assert provider.need_credentials is False + assert provider._get_tool_labels() == [] + + +def test_builtin_tool_provider_forked_tool_runtime_is_initialized(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + tool = provider.get_tool("tool_a") + assert tool is not None + assert isinstance(tool.runtime, ToolRuntime) + assert tool.runtime.tenant_id == "" + tool.runtime.invoke_from = InvokeFrom.DEBUGGER + assert tool.runtime.invoke_from == InvokeFrom.DEBUGGER diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py new file mode 100644 index 0000000000..62cfb6ce5b --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import math +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort +from core.tools.builtin_tool.providers.audio.audio import AudioToolProvider +from core.tools.builtin_tool.providers.audio.tools.asr import ASRTool +from core.tools.builtin_tool.providers.audio.tools.tts import TTSTool +from core.tools.builtin_tool.providers.code.code import CodeToolProvider +from core.tools.builtin_tool.providers.code.tools.simple_code import SimpleCode +from core.tools.builtin_tool.providers.time.time import WikiPediaProvider +from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool +from core.tools.builtin_tool.providers.time.tools.localtime_to_timestamp import LocaltimeToTimestampTool +from core.tools.builtin_tool.providers.time.tools.timestamp_to_localtime import TimestampToLocaltimeTool +from core.tools.builtin_tool.providers.time.tools.timezone_conversion import TimezoneConversionTool +from core.tools.builtin_tool.providers.time.tools.weekday import WeekdayTool +from core.tools.builtin_tool.providers.webscraper.tools.webscraper import WebscraperTool +from core.tools.builtin_tool.providers.webscraper.webscraper import WebscraperProvider +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from dify_graph.file.enums import FileType +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey + + +def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return tool_cls(provider="provider-a", entity=entity, runtime=runtime) + + +def _raise_runtime_error(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("boom") + + +def test_current_time_tool(): + current_tool = _build_builtin_tool(CurrentTimeTool) + utc_text = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "UTC"}))[0].message.text + assert utc_text + + invalid_tz = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "Invalid/TZ"}))[0].message.text + assert "Invalid timezone" in invalid_tz + + +def test_localtime_to_timestamp_tool(): + localtime_tool = _build_builtin_tool(LocaltimeToTimestampTool) + ts_message = list( + localtime_tool.invoke(user_id="u", tool_parameters={"localtime": "2024-01-01 10:00:00", "timezone": "UTC"}) + )[0].message.text + ts_value = float(ts_message.strip()) + assert math.isfinite(ts_value) + assert ts_value >= 0 + with pytest.raises(ToolInvokeError): + LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC") + + +def test_timestamp_to_localtime_tool(): + to_local_tool = _build_builtin_tool(TimestampToLocaltimeTool) + local_text = list(to_local_tool.invoke(user_id="u", tool_parameters={"timestamp": 1704067200, "timezone": "UTC"}))[ + 0 + ].message.text + assert "2024" in local_text + with pytest.raises(ToolInvokeError): + TimestampToLocaltimeTool.timestamp_to_localtime("bad", "UTC") # type: ignore[arg-type] + + +def test_timezone_conversion_tool(): + timezone_tool = _build_builtin_tool(TimezoneConversionTool) + converted = list( + timezone_tool.invoke( + user_id="u", + tool_parameters={ + "current_time": "2024-01-01 08:00:00", + "current_timezone": "UTC", + "target_timezone": "Asia/Tokyo", + }, + ) + )[0].message.text + assert converted.startswith("2024-01-01") + with pytest.raises(ToolInvokeError): + TimezoneConversionTool.timezone_convert("bad", "UTC", "Asia/Tokyo") + + +def test_weekday_tool(): + weekday_tool = _build_builtin_tool(WeekdayTool) + valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text + assert "January 1, 2024" in valid + invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ + 0 + ].message.text + assert "Invalid date" in invalid + with pytest.raises(ValueError, match="Month is required"): + list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "day": 1})) + + +def test_simple_code_valid_execution(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + lambda *a: "ok", + ) + result = list( + simple_code.invoke( + user_id="u", + tool_parameters={"language": "python3", "code": "print(1)"}, + ) + )[0].message.text + assert result == "ok" + + +def test_simple_code_invalid_language(): + simple_code = _build_builtin_tool(SimpleCode) + + with pytest.raises(ValueError, match="Only python3 and javascript"): + list(simple_code.invoke(user_id="u", tool_parameters={"language": "go", "code": "fmt.Println(1)"})) + + +def test_simple_code_execution_error(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(simple_code.invoke(user_id="u", tool_parameters={"language": "python3", "code": "print(1)"})) + + +def test_webscraper_empty_url(): + webscraper = _build_builtin_tool(WebscraperTool) + empty = list(webscraper.invoke(user_id="u", tool_parameters={"url": ""}))[0].message.text + assert empty == "Please input url" + + +def test_webscraper_fetch(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + full = list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))[0].message.text + assert full == "page" + + +def test_webscraper_summary(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + monkeypatch.setattr(webscraper, "summary", lambda user_id, content: "summary") + summarized = list( + webscraper.invoke( + user_id="u", + tool_parameters={"url": "https://example.com", "generate_summary": True}, + ) + )[0].message.text + assert summarized == "summary" + + +def test_webscraper_fetch_error(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr( + "core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"})) + + +def test_asr_invalid_file(): + asr = _build_builtin_tool(ASRTool) + file_obj = SimpleNamespace(type=FileType.DOCUMENT) + invalid_file = list(asr.invoke(user_id="u", tool_parameters={"audio_file": file_obj}))[0].message.text + assert "not a valid audio file" in invalid_file + + +def test_asr_valid_file_invocation(monkeypatch): + asr = _build_builtin_tool(ASRTool) + model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + audio_file = SimpleNamespace(type=FileType.AUDIO) + ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text + assert ok == "transcript" + + +def test_asr_available_models_and_runtime_parameters(monkeypatch): + asr = _build_builtin_tool(ASRTool) + provider_model = type("PM", (), {"provider": "p", "models": [type("Model", (), {"model": "m"})()]})() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelProviderService.get_models_by_model_type", + lambda *a, **k: [provider_model], + ) + assert asr.get_available_models() == [("p", "m")] + assert asr.get_runtime_parameters()[0].name == "model" + + +def test_tts_invoke_returns_messages(monkeypatch): + tts = _build_builtin_tool(TTSTool) + voices_model_instance = type( + "TTSM", + (), + { + "get_tts_voices": lambda self: [{"value": "voice-1"}], + "invoke_tts": lambda self, **kwargs: [b"a", b"b"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + ) + messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + + +def test_tts_get_available_models_requires_runtime(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + tts.get_available_models() + + +def test_tts_tool_raises_when_runtime_missing(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +@pytest.mark.parametrize( + "voices", + [[{"value": None}], []], +) +def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): + tts = _build_builtin_tool(TTSTool) + tts.runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + model_without_voice = type( + "TTSModelNoVoice", + (), + { + "get_tts_voices": lambda self: voices, + "invoke_tts": lambda self, **kwargs: [b"x"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + ) + with pytest.raises(ValueError, match="no voice available"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +def test_tts_tool_get_available_models_and_runtime_parameters(monkeypatch): + tts = _build_builtin_tool(TTSTool) + + model_1 = SimpleNamespace( + model="model-a", + model_properties={ModelPropertyKey.VOICES: [{"mode": "v1", "name": "Voice 1"}]}, + ) + model_2 = SimpleNamespace(model="model-b", model_properties={}) + provider_models = [SimpleNamespace(provider="provider-a", models=[model_1, model_2])] + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelProviderService.get_models_by_model_type", + lambda *args, **kwargs: provider_models, + ) + + available_models = tts.get_available_models() + assert available_models == [ + ("provider-a", "model-a", [{"mode": "v1", "name": "Voice 1"}]), + ("provider-a", "model-b", []), + ] + + runtime_parameters = tts.get_runtime_parameters() + assert runtime_parameters[0].name == "model" + assert runtime_parameters[0].required is True + assert runtime_parameters[0].options[0].value == "provider-a#model-a" + assert runtime_parameters[1].name == "voice#provider-a#model-a" + + +def test_provider_classes_and_builtin_sort(monkeypatch): + # Use object.__new__ to avoid YAML-loading __init__; only pass-through validation is exercised. + # Ensure pass-through _validate_credentials methods are executed. + AudioToolProvider._validate_credentials(object.__new__(AudioToolProvider), "u", {}) + CodeToolProvider._validate_credentials(object.__new__(CodeToolProvider), "u", {}) + WikiPediaProvider._validate_credentials(object.__new__(WikiPediaProvider), "u", {}) + WebscraperProvider._validate_credentials(object.__new__(WebscraperProvider), "u", {}) + + providers = [SimpleNamespace(name="b"), SimpleNamespace(name="a")] + monkeypatch.setattr(BuiltinToolProviderSort, "_position", {}) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.get_tool_position_map", + lambda _: {"a": 0, "b": 1}, + ) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.sort_by_position_map", + lambda position, values, name_func: sorted(values, key=lambda x: name_func(x)), + ) + sorted_providers = BuiltinToolProviderSort.sort(providers) + assert [p.name for p in sorted_providers] == ["a", "b"] diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py new file mode 100644 index 0000000000..79b8eaaa87 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import httpx +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.custom_tool.tool import ApiTool, ParsedResponse +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError + + +def _build_tool(*, openapi: dict | None = None) -> ApiTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + bundle = ApiToolBundle( + server_url="https://api.example.com/items/{id}", + method="GET", + summary="summary", + operation_id="op-id", + parameters=[], + author="author", + openapi=openapi or {"parameters": []}, + ) + runtime = ToolRuntime( + tenant_id="tenant-1", + invoke_from=InvokeFrom.DEBUGGER, + credentials={"auth_type": "api_key_header", "api_key_value": "k"}, + ) + return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="provider-id") + + +def test_parsed_response_to_string(): + assert ParsedResponse({"a": 1}, True).to_string() == '{"a": 1}' + assert ParsedResponse("ok", False).to_string() == "ok" + + +def test_api_tool_fork_runtime_and_validate_credentials(monkeypatch): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, ApiTool) + assert forked.runtime.tenant_id == "tenant-2" + + tool.api_bundle = None # type: ignore[assignment] + with pytest.raises(ValueError, match="api_bundle is required"): + tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + + tool = _build_tool() + assert tool.validate_credentials(credentials={}, parameters={}, format_only=True) == "" + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {"Authorization": "Bearer x"}) + monkeypatch.setattr( + tool, + "do_http_request", + lambda url, method, headers, parameters: httpx.Response(200, json={"ok": True}), + ) + result = tool.validate_credentials(credentials={}, parameters={"a": 1}, format_only=False) + assert result == '{"ok": true}' + + +def test_assembling_request_auth_header_assembly(): + tool = _build_tool() + + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "k" + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "bearer", + "api_key_value": "abc", + } + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Bearer abc" + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_header_prefix": "basic", "api_key_value": "abc"} + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Basic abc" + + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_value": "abc"} + assert tool.assembling_request(parameters={}) == {} + + +def test_assembling_request_runtime_auth_errors(): + tool = _build_tool() + + tool.runtime = None + with pytest.raises(ToolProviderCredentialValidationError, match="runtime not initialized"): + tool.assembling_request(parameters={}) + + tool.runtime = ToolRuntime(tenant_id="tenant", credentials={}) + with pytest.raises(ToolProviderCredentialValidationError, match="Missing auth_type"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header"} + with pytest.raises(ToolProviderCredentialValidationError, match="Missing api_key_value"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": 123} + with pytest.raises(ToolProviderCredentialValidationError, match="must be a string"): + tool.assembling_request(parameters={}) + + +def test_assembling_request_parameter_validation_and_defaults(): + tool = _build_tool() + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "x"} + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default=None), + ] + with pytest.raises(ToolParameterValidationError, match="Missing required parameter required_param"): + tool.assembling_request(parameters={}) + + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default="d"), + ] + params = {} + tool.assembling_request(parameters=params) + assert params["required_param"] == "d" + + +def test_validate_and_parse_response_branches(): + tool = _build_tool() + + with pytest.raises(ToolInvokeError, match="status code 500"): + tool.validate_and_parse_response(httpx.Response(500, text="boom")) + + empty = tool.validate_and_parse_response(httpx.Response(200, content=b"")) + assert empty.is_json is False + assert "Empty response from the tool" in str(empty.content) + + json_resp = tool.validate_and_parse_response( + httpx.Response(200, json={"a": 1}, headers={"content-type": "application/json"}) + ) + assert json_resp.is_json is True + assert json_resp.content == {"a": 1} + + non_json_type = tool.validate_and_parse_response( + httpx.Response(200, text='{"a": 1}', headers={"content-type": "text/plain"}) + ) + assert non_json_type.is_json is False + assert non_json_type.content == '{"a": 1}' + + plain_resp = tool.validate_and_parse_response(httpx.Response(200, text="plain")) + assert plain_resp.is_json is False + assert plain_resp.content == "plain" + + with pytest.raises(ValueError, match="Invalid response type"): + tool.validate_and_parse_response("invalid") # type: ignore[arg-type] + + +def test_get_parameter_value_and_type_conversion_helpers(): + tool = _build_tool() + + assert tool.get_parameter_value({"name": "x"}, {"x": 1}) == 1 + assert tool.get_parameter_value({"name": "x", "required": False, "schema": {"default": "d"}}, {}) == "d" + with pytest.raises(ToolParameterValidationError, match="Missing required parameter x"): + tool.get_parameter_value({"name": "x", "required": True}, {}) + + assert tool._convert_body_property_any_of({}, "12", [{"type": "integer"}]) == 12 + assert tool._convert_body_property_any_of({}, "1.5", [{"type": "number"}]) == 1.5 + assert tool._convert_body_property_any_of({}, "true", [{"type": "boolean"}]) is True + assert tool._convert_body_property_any_of({}, "", [{"type": "null"}]) is None + assert tool._convert_body_property_any_of({}, "x", [{"anyOf": [{"type": "string"}]}]) == "x" + + assert tool._convert_body_property_type({"type": "integer"}, "1") == 1 + assert tool._convert_body_property_type({"type": "number"}, "1.2") == 1.2 + assert tool._convert_body_property_type({"type": "string"}, 1) == "1" + assert tool._convert_body_property_type({"type": "boolean"}, 1) is True + assert tool._convert_body_property_type({"type": "null"}, None) is None + assert tool._convert_body_property_type({"type": "object"}, '{"a":1}') == {"a": 1} + assert tool._convert_body_property_type({"type": "array"}, "[1,2]") == [1, 2] + assert tool._convert_body_property_type({"type": "invalid"}, "v") == "v" + assert tool._convert_body_property_type({"anyOf": [{"type": "integer"}]}, "2") == 2 + + +def test_do_http_request_builds_arguments_and_handles_invalid_method(monkeypatch): + openapi = { + "parameters": [ + {"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}, + {"name": "q", "in": "query", "required": False, "schema": {"default": ""}}, + {"name": "X-Extra", "in": "header", "required": False, "schema": {"default": "x"}}, + {"name": "sid", "in": "cookie", "required": False, "schema": {"default": "cookie1"}}, + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["count"], + "properties": { + "count": {"type": "integer"}, + "name": {"type": "string", "default": "n"}, + }, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "v"} + headers = {} + captured = {} + + def _fake_get(url, **kwargs): + captured["url"] = url + captured["kwargs"] = kwargs + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.get", _fake_get) + response = tool.do_http_request( + "https://api.example.com/items/{id}", + "GET", + headers=headers, + parameters={"id": "123", "count": "2", "q": "search"}, + ) + + assert isinstance(response, httpx.Response) + assert captured["url"].endswith("/items/123") + assert captured["kwargs"]["params"]["q"] == "search" + assert captured["kwargs"]["params"]["key"] == "v" + assert captured["kwargs"]["headers"]["Content-Type"] == "application/json" + + invalid_method_tool = _build_tool(openapi={"parameters": []}) + with pytest.raises(ValueError, match="Invalid http method"): + invalid_method_tool.do_http_request("https://api.example.com", "TRACE", headers={}, parameters={}) + + +def test_do_http_request_handles_file_upload_and_invoke_paths(monkeypatch): + openapi = { + "parameters": [], + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": {"file": {"format": "binary"}}, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "k"} + fake_file = SimpleNamespace(filename="a.txt", mime_type="text/plain") + captured = {} + + def _fake_post(url, **kwargs): + captured["headers"] = kwargs["headers"] + captured["files"] = kwargs["files"] + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.download", lambda _: b"file-bytes") + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.post", _fake_post) + response = tool.do_http_request( + "https://api.example.com/upload", + "POST", + headers={}, + parameters={"file": fake_file}, + ) + assert isinstance(response, httpx.Response) + assert "Content-Type" not in captured["headers"] + assert captured["files"][0][0] == "file" + + # _invoke JSON path + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {}) + monkeypatch.setattr(tool, "do_http_request", lambda *args, **kwargs: httpx.Response(200, text='{"a":1}')) + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse({"a": 1}, True)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.TEXT] + + # _invoke text path + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse("plain", False)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert len(messages) == 1 + assert messages[0].message.text == "plain" diff --git a/api/tests/unit_tests/core/tools/test_custom_tool_provider.py b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py new file mode 100644 index 0000000000..93ae217e24 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolProviderType + + +def _db_provider() -> SimpleNamespace: + bundle = ApiToolBundle( + server_url="https://api.example.com/items", + method="GET", + summary="List items", + operation_id="list_items", + parameters=[], + author="author", + openapi={"parameters": []}, + ) + return SimpleNamespace( + id="provider-id", + tenant_id="tenant-1", + name="provider-a", + description="desc", + icon="icon.svg", + user=SimpleNamespace(name="Alice"), + tools=[bundle], + ) + + +def test_api_tool_provider_from_db_and_parse_tool_bundle(): + controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_HEADER) + assert controller.provider_type == ToolProviderType.API + assert any(c.name == "api_key_value" for c in controller.entity.credentials_schema) + + tool = controller._parse_tool_bundle(_db_provider().tools[0]) + assert isinstance(tool, ApiTool) + assert tool.entity.identity.provider == "provider-id" + + +def test_api_tool_provider_from_db_query_auth_and_none_auth(): + query_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_QUERY) + assert any(c.name == "api_key_query_param" for c in query_controller.entity.credentials_schema) + + none_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + assert [c.name for c in none_controller.entity.credentials_schema] == ["auth_type"] + + +def test_api_tool_provider_load_get_tools_and_get_tool(): + controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + loaded = controller.load_bundled_tools(_db_provider().tools) + assert len(loaded) == 1 + + assert isinstance(controller.get_tool("list_items"), ApiTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + # Return cached tools without querying database. + cached = controller.get_tools("tenant-1") + assert len(cached) == 1 + + # Force DB fetch branch. + controller.tools = [] + provider_with_tools = _db_provider() + with patch("core.tools.custom_tool.provider.db") as mock_db: + scalars_result = Mock() + scalars_result.all.return_value = [provider_with_tools] + mock_db.session.scalars.return_value = scalars_result + tools = controller.get_tools("tenant-1") + assert len(tools) == 1 diff --git a/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py new file mode 100644 index 0000000000..23c0be9487 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py @@ -0,0 +1,145 @@ +"""Unit tests for DatasetRetrieverTool behavior and retrieval wiring.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE) + + +def test_get_dataset_tools_returns_empty_for_empty_dataset_ids() -> None: + # Arrange + retrieve_config = _retrieve_config() + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=[], + retrieve_config=retrieve_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_returns_empty_for_missing_retrieve_config() -> None: + # Arrange + dataset_ids = ["d1"] + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=dataset_ids, + retrieve_config=None, # type: ignore[arg-type] + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_builds_tool_and_restores_strategy() -> None: + # Arrange + retrieve_config = _retrieve_config() + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + + # Act + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=retrieve_config, + return_resource=True, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={"x": 1}, + ) + + # Assert + assert len(tools) == 1 + assert tools[0].entity.identity.name == "dataset_tool" + assert retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + + +def _build_dataset_tool() -> tuple[DatasetRetrieverTool, SimpleNamespace]: + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=_retrieve_config(), + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + return tools[0], retrieval_tool + + +def test_runtime_parameters_shape() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + params = tool.get_runtime_parameters() + + # Assert + assert len(params) == 1 + assert params[0].name == "query" + + +def test_empty_query_behavior() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + empty_query = list(tool.invoke(user_id="u", tool_parameters={})) + + # Assert + assert len(empty_query) == 1 + assert empty_query[0].message.text == "please input query" + + +def test_query_invocation_result() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = list(tool.invoke(user_id="u", tool_parameters={"query": "hello"})) + + # Assert + assert len(result) == 1 + assert result[0].message.text == "result:hello" + + +def test_validate_credentials() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = tool.validate_credentials(credentials={}, parameters={}, format_only=False) + + # Assert + assert result is None diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool.py b/api/tests/unit_tests/core/tools/test_mcp_tool.py new file mode 100644 index 0000000000..eaf054de59 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import base64 +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.mcp.types import ( + BlobResourceContents, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + TextResourceContents, +) +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_tool(*, with_output_schema: bool = True) -> MCPTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="remote-tool", + label=I18nObject(en_US="remote-tool"), + provider="provider-id", + ), + parameters=[], + output_schema={"type": "object"} if with_output_schema else {}, + ) + return MCPTool( + entity=entity, + runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER), + tenant_id="tenant-1", + icon="icon.svg", + server_url="https://mcp.example.com", + provider_id="provider-id", + headers={"x-auth": "token"}, + ) + + +def test_mcp_tool_provider_type_and_fork_runtime(): + tool = _build_mcp_tool() + assert tool.tool_provider_type() == ToolProviderType.MCP + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, MCPTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.provider_id == "provider-id" + + +def test_mcp_tool_text_and_json_processing_helpers(): + tool = _build_mcp_tool() + + json_messages = list(tool._process_text_content(TextContent(type="text", text='{"a": 1}'))) + assert json_messages[0].type == ToolInvokeMessage.MessageType.JSON + + plain_messages = list(tool._process_text_content(TextContent(type="text", text="not-json"))) + assert plain_messages[0].type == ToolInvokeMessage.MessageType.TEXT + assert plain_messages[0].message.text == "not-json" + + list_messages = list(tool._process_json_content([{"k": 1}, {"k": 2}])) + assert [m.type for m in list_messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.JSON] + + mixed_list_messages = list(tool._process_json_list([{"k": 1}, 2])) + assert len(mixed_list_messages) == 1 + assert mixed_list_messages[0].type == ToolInvokeMessage.MessageType.TEXT + + primitive_messages = list(tool._process_json_content(123)) + assert primitive_messages[0].message.text == "123" + + +def test_mcp_tool_usage_extraction_helpers(): + usage = MCPTool._extract_usage_dict({"usage": {"total_tokens": 9}}) + assert usage == {"total_tokens": 9} + + usage = MCPTool._extract_usage_dict({"metadata": {"usage": {"prompt_tokens": 3, "completion_tokens": 2}}}) + assert usage == {"prompt_tokens": 3, "completion_tokens": 2} + + usage = MCPTool._extract_usage_dict({"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}) + assert usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} + + usage = MCPTool._extract_usage_dict({"nested": [{"deep": {"usage": {"total_tokens": 7}}}]}) + assert usage == {"total_tokens": 7} + + result_with_usage = CallToolResult(content=[], _meta={"usage": {"prompt_tokens": 1, "completion_tokens": 2}}) + derived = MCPTool._derive_usage_from_result(result_with_usage) + assert derived.prompt_tokens == 1 + assert derived.completion_tokens == 2 + + result_without_usage = CallToolResult(content=[], _meta=None) + derived = MCPTool._derive_usage_from_result(result_without_usage) + assert derived.total_tokens == 0 + + +def test_mcp_tool_invoke_handles_content_types_and_structured_output(): + tool = _build_mcp_tool() + img_data = base64.b64encode(b"img").decode() + blob_data = base64.b64encode(b"blob").decode() + result = CallToolResult( + content=[ + TextContent(type="text", text='{"a": 1}'), + ImageContent(type="image", data=img_data, mimeType="image/png"), + EmbeddedResource( + type="resource", + resource=TextResourceContents(uri="file:///tmp/a.txt", text="embedded-text"), + ), + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="file:///tmp/b.bin", + blob=blob_data, + mimeType="application/octet-stream", + ), + ), + ], + structuredContent={"x": 1}, + _meta={"usage": {"prompt_tokens": 2, "completion_tokens": 3}}, + ) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"a": 1})) + + types = [m.type for m in messages] + assert ToolInvokeMessage.MessageType.JSON in types + assert ToolInvokeMessage.MessageType.BLOB in types + assert ToolInvokeMessage.MessageType.TEXT in types + assert ToolInvokeMessage.MessageType.VARIABLE in types + assert tool.latest_usage.total_tokens == 5 + + +def test_mcp_tool_invoke_raises_for_unsupported_embedded_resource(): + tool = _build_mcp_tool() + # Use model_construct to bypass pydantic validation and force unsupported resource path. + bad_resource = EmbeddedResource.model_construct(type="resource", resource=object()) + result = CallToolResult(content=[bad_resource], _meta=None) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + with pytest.raises(ToolInvokeError, match="Unsupported embedded resource type"): + list(tool.invoke(user_id="user-1", tool_parameters={})) + + +def test_mcp_tool_handle_none_parameter_filters_empty_values(): + tool = _build_mcp_tool() + cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"}) + assert cleaned == {"a": 1, "e": "ok"} diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py new file mode 100644 index 0000000000..1060d19ab1 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_entity(*, icon: str = "icon.svg") -> MCPProviderEntity: + now = datetime.now() + return MCPProviderEntity( + id="db-id", + provider_id="provider-id", + name="MCP Provider", + tenant_id="tenant-1", + user_id="user-1", + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[ + { + "name": "remote-tool", + "description": "remote tool", + "inputSchema": {}, + "outputSchema": {"type": "object"}, + } + ], + icon=icon, + created_at=now, + updated_at=now, + ) + + +def test_mcp_tool_provider_controller_from_entity_and_get_tools(): + entity = _build_mcp_entity() + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_entity(entity) + + assert controller.provider_type == ToolProviderType.MCP + tool = controller.get_tool("remote-tool") + assert isinstance(tool, MCPTool) + assert tool.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], MCPTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_mcp_tool_provider_controller_from_entity_requires_icon(): + entity = _build_mcp_entity(icon="") + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + with pytest.raises(ValueError, match="icon is required"): + MCPToolProviderController.from_entity(entity) + + +def test_mcp_tool_provider_controller_from_db_delegates_to_entity(): + entity = _build_mcp_entity() + db_provider = Mock() + db_provider.to_entity.return_value = entity + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_db(db_provider) + assert isinstance(controller, MCPToolProviderController) diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool.py b/api/tests/unit_tests/core/tools/test_plugin_tool.py new file mode 100644 index 0000000000..4378432a0f --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolParameter +from core.tools.plugin_tool.tool import PluginTool + + +def _build_plugin_tool(*, has_runtime_parameters: bool) -> PluginTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ], + has_runtime_parameters=has_runtime_parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, credentials={"api_key": "x"}) + return PluginTool( + entity=entity, + runtime=runtime, + tenant_id="tenant-1", + icon="icon.svg", + plugin_unique_identifier="plugin-uid", + ) + + +def test_plugin_tool_invoke_and_fork_runtime(): + tool = _build_plugin_tool(has_runtime_parameters=False) + manager = Mock() + manager.invoke.return_value = iter([tool.create_text_message("ok")]) + + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + with patch( + "core.tools.plugin_tool.tool.convert_parameters_to_plugin_format", + return_value={"converted": 1}, + ): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"raw": 1})) + + assert [m.message.text for m in messages] == ["ok"] + manager.invoke.assert_called_once() + assert manager.invoke.call_args.kwargs["tool_parameters"] == {"converted": 1} + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, PluginTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.plugin_unique_identifier == "plugin-uid" + + +def test_plugin_tool_get_runtime_parameters_branches(): + tool = _build_plugin_tool(has_runtime_parameters=False) + assert tool.get_runtime_parameters() == tool.entity.parameters + + tool = _build_plugin_tool(has_runtime_parameters=True) + cached = [ + ToolParameter.get_simple_instance( + name="k", + llm_description="k", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + tool.runtime_parameters = cached + assert tool.get_runtime_parameters() == cached + + tool.runtime_parameters = None + manager = Mock() + returned = [ + ToolParameter.get_simple_instance( + name="dyn", + llm_description="dyn", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + manager.get_runtime_parameters.return_value = returned + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + assert tool.get_runtime_parameters(conversation_id="c1", app_id="a1", message_id="m1") == returned + assert tool.runtime_parameters == returned diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py new file mode 100644 index 0000000000..5ef03cc6ca --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolProviderEntityWithPlugin, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool + + +def _build_controller() -> PluginToolProviderController: + tool_entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + entity = ToolProviderEntityWithPlugin( + identity=ToolProviderIdentity( + author="author", + name="provider-a", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ), + credentials_schema=[], + plugin_id="plugin-id", + tools=[tool_entity], + ) + return PluginToolProviderController( + entity=entity, + plugin_id="plugin-id", + plugin_unique_identifier="plugin-uid", + tenant_id="tenant-1", + ) + + +def test_plugin_tool_provider_controller_basic_behaviors(): + controller = _build_controller() + assert controller.provider_type == ToolProviderType.PLUGIN + + tool = controller.get_tool("tool-a") + assert isinstance(tool, PluginTool) + assert tool.runtime.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], PluginTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_validate_credentials_success(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = True + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) + + manager.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant-1", + user_id="u1", + provider="provider-a", + credentials={"api_key": "x"}, + ) + + +def test_validate_credentials_failure(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = False + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py new file mode 100644 index 0000000000..a5242a78c5 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -0,0 +1,119 @@ +"""Unit tests for core.tools.signature covering signing and verification invariants.""" + +from __future__ import annotations + +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature + + +def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=False) + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + assert parsed.scheme == "https" + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is True + + +def test_sign_tool_file_for_external_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x04" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=True) + parsed = urlparse(url) + + assert parsed.scheme == "https" + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + + +def test_verify_tool_file_signature_rejects_invalid_sign(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, "bad-signature") is False + + +def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + monkeypatch.setattr("core.tools.signature.time.time", lambda: int(timestamp) + 99) + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False + + +def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + + url = sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] + + +def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + + url = sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py new file mode 100644 index 0000000000..40c107667c --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ( + ToolEngineInvokeError, + ToolInvokeError, + ToolParameterValidationError, +) +from core.tools.tool_engine import ToolEngine + + +class _DummyTool(Tool): + result: Any + raise_error: Exception | None + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime): + super().__init__(entity=entity, runtime=runtime) + self.result = [self.create_text_message("ok")] + self.raise_error = None + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + if self.raise_error: + raise self.raise_error + if isinstance(self.result, list | Generator): + yield from self.result + else: + yield self.result + + +def _build_tool(with_llm_parameter: bool = False) -> _DummyTool: + parameters = [] + if with_llm_parameter: + parameters = [ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, runtime_parameters={"rt": 1}) + return _DummyTool(entity=entity, runtime=runtime) + + +def test_convert_tool_response_to_str_and_extract_binary_messages(): + tool = _build_tool() + messages = [ + tool.create_text_message("hello"), + tool.create_link_message("https://example.com"), + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="https://example.com/a.png"), + meta={"mime_type": "image/png"}, + ), + tool.create_json_message({"a": 1}), + tool.create_json_message({"a": 1}, suppress_output=True), + ] + text = ToolEngine._convert_tool_response_to_str(messages) + assert "hello" in text + assert "result link: https://example.com." in text + assert '"a": 1' in text + + blob_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.TextMessage(text="https://example.com/blob.bin"), + meta={"mime_type": "application/octet-stream"}, + ) + link_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://example.com/file.pdf"), + meta={"mime_type": "application/pdf"}, + ) + binaries = list(ToolEngine._extract_tool_response_binary_and_text([messages[2], blob_message, link_message])) + assert [b.mimetype for b in binaries] == ["image/png", "application/octet-stream", "application/pdf"] + + with pytest.raises(ValueError, match="missing meta data"): + list( + ToolEngine._extract_tool_response_binary_and_text( + [ + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="x"), + ) + ] + ) + ) + + +def test_create_message_files_and_invoke_generator(): + binaries = [ + ToolInvokeMessageBinary(mimetype="image/png", url="https://example.com/abc.png"), + ToolInvokeMessageBinary(mimetype="audio/wav", url="https://example.com/def.wav"), + ] + created = [] + + def _message_file_factory(**kwargs): + obj = SimpleNamespace(id=f"mf-{len(created) + 1}", **kwargs) + created.append(obj) + return obj + + with patch("core.tools.tool_engine.MessageFile", side_effect=_message_file_factory): + with patch("core.tools.tool_engine.db") as mock_db: + ids = ToolEngine._create_message_files( + tool_messages=binaries, + agent_message=SimpleNamespace(id="msg-1"), + invoke_from=InvokeFrom.DEBUGGER, + user_id="user-1", + ) + + assert ids == ["mf-1", "mf-2"] + assert mock_db.session.add.call_count == 2 + mock_db.session.close.assert_called_once() + + tool = _build_tool() + invoked = list(ToolEngine._invoke(tool, {"a": 1}, user_id="u")) + assert invoked[0].type == ToolInvokeMessage.MessageType.TEXT + assert isinstance(invoked[-1], ToolInvokeMeta) + assert invoked[-1].error is None + + +def test_generic_invoke_success_and_error_paths(): + tool = _build_tool() + callback = Mock() + callback.on_tool_execution.side_effect = lambda **kwargs: kwargs["tool_outputs"] + response = list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=callback, + workflow_call_depth=0, + conversation_id="c1", + app_id="a1", + message_id="m1", + ) + ) + assert response[0].message.text == "ok" + callback.on_tool_start.assert_called_once() + callback.on_tool_execution.assert_called_once() + + tool.raise_error = RuntimeError("boom") + error_callback = Mock() + error_callback.on_tool_execution.side_effect = lambda **kwargs: list(kwargs["tool_outputs"]) + with pytest.raises(RuntimeError, match="boom"): + list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=error_callback, + workflow_call_depth=0, + ) + ) + error_callback.on_tool_error.assert_called_once() + + +def test_agent_invoke_success(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + meta = ToolInvokeMeta.empty() + + with patch.object(ToolEngine, "_invoke", return_value=iter([tool.create_text_message("ok"), meta])): + with patch( + "core.tools.tool_engine.ToolFileMessageTransformer.transform_tool_invoke_messages", + side_effect=lambda messages, **kwargs: messages, + ): + with patch.object(ToolEngine, "_extract_tool_response_binary_and_text", return_value=iter([])): + with patch.object(ToolEngine, "_create_message_files", return_value=[]): + result_text, message_files, result_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters="hello", + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert result_text == "ok" + assert message_files == [] + assert result_meta.error is None + callback.on_tool_start.assert_called_once() + callback.on_tool_end.assert_called_once() + + +def test_agent_invoke_param_validation_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolParameterValidationError("bad-param")): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool parameters validation error" in error_text + assert files == [] + assert error_meta.error + + +def test_agent_invoke_engine_meta_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + engine_error = ToolEngineInvokeError(ToolInvokeMeta.error_instance("meta failure")) + + with patch.object(ToolEngine, "_invoke", side_effect=engine_error): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "meta failure" in error_text + assert files == [] + assert error_meta.error == "meta failure" + + +def test_agent_invoke_tool_invoke_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolInvokeError("invoke boom")): + error_text, files, _ = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool invoke error" in error_text + assert files == [] diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py new file mode 100644 index 0000000000..cca8254dd6 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -0,0 +1,249 @@ +"""Unit tests for `ToolFileManager` behavior. + +Covers signing/verification, file persistence flows, and retrieval APIs with +mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to +avoid real IO. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest + +from core.tools.tool_file_manager import ToolFileManager + + +def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100) + + url = ToolFileManager.sign_file("tf-1", ".png") + return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&")) + + +def _patch_session_factory(session: Mock): + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm) + + +def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + url = ToolFileManager.sign_file("tf-1", ".png") + assert "/files/tools/tf-1.png" in url + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True + + +def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False + + +def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0) + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False + + +def test_create_file_by_raw_stores_file_and_persists_record() -> None: + manager = ToolFileManager() + session = Mock() + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-1") + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.guess_extension", return_value=".txt"), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="abc")), + _patch_session_factory(session), + ): + file_model = manager.create_file_by_raw( + user_id="u1", + tenant_id="t1", + conversation_id="c1", + file_binary=b"hello", + mimetype="text/plain", + filename="readme", + ) + + assert file_model.name.endswith(".txt") + storage.save.assert_called_once() + session.add.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_downloads_and_persists_record() -> None: + manager = ToolFileManager() + response = Mock() + response.content = b"binary" + response.headers = {"Content-Type": "application/octet-stream"} + response.raise_for_status.return_value = None + session = Mock() + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-2") + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")), + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response), + ): + file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + assert file_model.file_key.startswith("tools/t1/") + storage.save.assert_called_once() + session.add.assert_called_once_with(file_model) + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_raises_on_timeout() -> None: + manager = ToolFileManager() + + with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")): + with pytest.raises(ValueError, match="timeout when downloading file"): + manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + +def test_get_file_binary_returns_none_when_not_found() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary("missing") + + # Assert + assert result is None + + +def test_get_file_binary_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"hello" + with _patch_session_factory(session): + result = manager.get_file_binary("id1") + + # Assert + assert result == (b"hello", "text/plain") + + +def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = None + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url=None) + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = tool_file + session.query.side_effect = [first_query, second_query] + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"img" + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result == (b"img", "image/png") + + +def test_get_file_generator_returns_none_when_toolfile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + stream, tool_file = manager.get_file_generator_by_tool_file_id("tool123") + + # Assert + assert stream is None + assert tool_file is None + + +def test_get_file_generator_returns_stream_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + stream = iter([b"a", b"b"]) + storage.load_stream.return_value = stream + with ( + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), + ): + result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") + assert list(result_stream) == [b"a", b"b"] + assert result_file == "validated-file" diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py new file mode 100644 index 0000000000..857f4aa178 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import PropertyMock, patch + +import pytest + +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + +class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + return None + + +def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController: + controller = object.__new__(ApiToolProviderController) + controller.provider_id = provider_id + return controller + + +def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderController: + controller = object.__new__(WorkflowToolProviderController) + controller.provider_id = provider_id + return controller + + +def test_tool_label_manager_filter_tool_labels(): + filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"]) + assert set(filtered) == {"search", "news"} + assert len(filtered) == 2 + + +def test_tool_label_manager_update_tool_labels_db(): + controller = _api_controller("api-1") + with patch("core.tools.tool_label_manager.db") as mock_db: + delete_query = mock_db.session.query.return_value.where.return_value + delete_query.delete.return_value = None + ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) + + delete_query.delete.assert_called_once() + # only one valid unique label should be inserted. + assert mock_db.session.add.call_count == 1 + mock_db.session.commit.assert_called_once() + + +def test_tool_label_manager_update_tool_labels_unsupported(): + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type] + + +def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): + with patch.object( + _ConcreteBuiltinToolProviderController, + "tool_labels", + new_callable=PropertyMock, + return_value=["search", "news"], + ): + builtin = object.__new__(_ConcreteBuiltinToolProviderController) + assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"] + + api = _api_controller("api-1") + with patch("core.tools.tool_label_manager.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = ["search", "news"] + labels = ToolLabelManager.get_tool_labels(api) + assert labels == ["search", "news"] + + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type] + + +def test_tool_label_manager_get_tools_labels_batch(): + assert ToolLabelManager.get_tools_labels([]) == {} + + api = _api_controller("api-1") + wf = _workflow_controller("wf-1") + records = [ + SimpleNamespace(tool_id="api-1", label_name="search"), + SimpleNamespace(tool_id="api-1", label_name="news"), + SimpleNamespace(tool_id="wf-1", label_name="utilities"), + ] + with patch("core.tools.tool_label_manager.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = records + labels = ToolLabelManager.get_tools_labels([api, wf]) + assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]} + + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item] diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py new file mode 100644 index 0000000000..0f73e22654 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -0,0 +1,899 @@ +from __future__ import annotations + +"""Unit tests for ToolManager behavior with mocked providers and collaborators.""" + +import json +import threading +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.tool_manager import ToolManager + + +class _SimpleContextVar: + def __init__(self): + self._is_set = False + self._value: Any = None + + def get(self): + if not self._is_set: + raise LookupError + return self._value + + def set(self, value: Any): + self._value = value + self._is_set = True + + +def _cm(session: Any): + context = Mock() + context.__enter__ = Mock(return_value=session) + context.__exit__ = Mock(return_value=False) + return context + + +def _setup_list_providers_from_api_mocks( + monkeypatch, + *, + session: Mock, + hardcoded_controller: SimpleNamespace, + plugin_controller: PluginToolProviderController, + api_controller: SimpleNamespace, + workflow_controller: SimpleNamespace, +): + mock_db = Mock() + mock_db.engine = object() + monkeypatch.setattr("core.tools.tool_manager.db", mock_db) + monkeypatch.setattr("core.tools.tool_manager.Session", lambda *args, **kwargs: _cm(session)) + monkeypatch.setattr( + ToolManager, + "list_builtin_providers", + Mock(return_value=[hardcoded_controller, plugin_controller]), + ) + monkeypatch.setattr( + ToolManager, + "list_default_builtin_providers", + Mock(return_value=[SimpleNamespace(provider="hardcoded")]), + ) + monkeypatch.setattr("core.tools.tool_manager.is_filtered", lambda *args, **kwargs: False) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.builtin_provider_to_user_provider", + lambda **kwargs: SimpleNamespace(name=kwargs["provider_controller"].entity.identity.name), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_controller", + Mock(side_effect=[api_controller, RuntimeError("invalid")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="api-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + Mock(side_effect=[workflow_controller, RuntimeError("deleted app")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="workflow-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolLabelManager.get_tools_labels", + Mock(side_effect=[{"api-1": ["search"]}, {"wf-1": ["utility"]}]), + ) + mock_mcp_service = Mock() + mock_mcp_service.list_providers.return_value = [SimpleNamespace(name="mcp-provider")] + monkeypatch.setattr("core.tools.tool_manager.MCPToolManageService", Mock(return_value=mock_mcp_service)) + monkeypatch.setattr("core.tools.tool_manager.BuiltinToolProviderSort.sort", lambda providers: providers) + + +@pytest.fixture(autouse=True) +def _reset_tool_manager_state(): + old_hardcoded = ToolManager._hardcoded_providers.copy() + old_loaded = ToolManager._builtin_providers_loaded + old_labels = ToolManager._builtin_tools_labels.copy() + try: + yield + finally: + ToolManager._hardcoded_providers = old_hardcoded + ToolManager._builtin_providers_loaded = old_loaded + ToolManager._builtin_tools_labels = old_labels + + +def test_get_hardcoded_provider_loads_cache_when_empty(): + provider = Mock() + ToolManager._hardcoded_providers = {} + + def _load(): + ToolManager._hardcoded_providers["weather"] = provider + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load) as mock_load: + assert ToolManager.get_hardcoded_provider("weather") is provider + + mock_load.assert_called_once() + + +def test_get_builtin_provider_returns_plugin_for_missing_hardcoded(): + hardcoded = Mock() + plugin_provider = Mock() + ToolManager._hardcoded_providers = {"time": hardcoded} + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + assert ToolManager.get_builtin_provider("time", "tenant-1") is hardcoded + assert ToolManager.get_builtin_provider("plugin/time", "tenant-1") is plugin_provider + + +def test_get_plugin_provider_uses_context_cache(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + provider_entity = SimpleNamespace(declaration=Mock(), plugin_id="pid", plugin_unique_identifier="uid") + + with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = provider_entity + controller = SimpleNamespace(name="controller") + with patch("core.tools.tool_manager.PluginToolProviderController", return_value=controller): + built = ToolManager.get_plugin_provider("provider-a", "tenant-1") + cached = ToolManager.get_plugin_provider("provider-a", "tenant-1") + + assert built is controller + assert cached is controller + mock_manager_cls.return_value.fetch_tool_provider.assert_called_once() + + +def test_get_plugin_provider_raises_when_provider_missing(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + + with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="plugin provider provider-a not found"): + ToolManager.get_plugin_provider("provider-a", "tenant-1") + + +def test_get_tool_runtime_builtin_without_credentials(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace(get_tool=Mock(return_value=tool), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="current_time", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.tenant_id == "tenant-1" + assert runtime.credentials == {} + + +def test_get_tool_runtime_builtin_missing_tool_raises(): + controller = SimpleNamespace(get_tool=Mock(return_value=None), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with pytest.raises(ToolProviderNotFoundError, match="builtin tool missing not found"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="missing", + tenant_id="tenant-1", + ) + + +def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.API_KEY.value, + credentials={"encrypted": "value"}, + expires_at=-1, + user_id="user-1", + ) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with patch("core.helper.credential_utils.check_credential_policy_compliance"): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + builtin_provider + ) + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key": "secret"} + cache = Mock() + with patch("core.tools.tool_manager.create_provider_encrypter", return_value=(encrypter, cache)): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.credentials == {"api_key": "secret"} + assert runtime.credential_type == CredentialType.API_KEY + + +@patch("core.tools.tool_manager.create_provider_encrypter") +@patch("core.plugin.impl.oauth.OAuthHandler") +@patch( + "services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client", + return_value={"client_id": "id"}, +) +@patch("core.tools.tool_manager.db") +@patch("core.tools.tool_manager.time.time", return_value=1000) +@patch("core.helper.credential_utils.check_credential_policy_compliance") +def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( + mock_check, + mock_time, + mock_db, + mock_get_oauth_client, + mock_oauth_handler_cls, + mock_create_provider_encrypter, +): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"encrypted": "value"}, + encrypted_credentials=None, + expires_at=1, + user_id="user-1", + ) + refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) + + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + encrypter = Mock() + encrypter.decrypt.return_value = {"token": "old"} + encrypter.encrypt.return_value = {"token": "encrypted"} + cache = Mock() + mock_create_provider_encrypter.return_value = (encrypter, cache) + mock_oauth_handler_cls.return_value.refresh_credentials.return_value = refreshed + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + assert builtin_provider.expires_at == refreshed.expires_at + assert builtin_provider.encrypted_credentials == json.dumps({"token": "encrypted"}) + mock_db.session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_get_tool_runtime_builtin_plugin_provider_deleted_raises(): + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(credentials_schema=[{"name": "k"}], oauth_schema=None) + plugin_controller.get_tool = Mock(return_value=Mock()) + plugin_controller.get_credentials_schema_by_type = Mock(return_value=[]) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_controller): + with patch("core.tools.tool_manager.is_valid_uuid", return_value=True): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.scalar.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="provider has been deleted"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + credential_id="uuid-id", + ) + + +def test_get_tool_runtime_api_path(): + api_tool = Mock() + api_tool.fork_tool_runtime.return_value = "api-runtime" + api_provider = Mock() + api_provider.get_tool.return_value = api_tool + + with patch.object(ToolManager, "get_api_provider_controller", return_value=(api_provider, {"c": "enc"})): + encrypter = Mock() + encrypter.decrypt.return_value = {"c": "dec"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + ) + == "api-runtime" + ) + + +def test_get_tool_runtime_workflow_path(): + workflow_provider = SimpleNamespace(tenant_id="tenant-1") + workflow_tool = Mock() + workflow_tool.fork_tool_runtime.return_value = "wf-runtime" + workflow_controller = Mock() + workflow_controller.get_tools.return_value = [workflow_tool] + session = Mock() + session.begin.return_value = _cm(None) + session.scalar.return_value = workflow_provider + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + return_value=workflow_controller, + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.WORKFLOW, + provider_id="wf-1", + tool_name="wf", + tenant_id="tenant-1", + ) + == "wf-runtime" + ) + + +def test_get_tool_runtime_plugin_path(): + with patch.object( + ToolManager, + "get_plugin_provider", + return_value=SimpleNamespace(get_tool=lambda _: "plugin-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.PLUGIN, + provider_id="plugin-1", + tool_name="p", + tenant_id="tenant-1", + ) + == "plugin-tool" + ) + + +def test_get_tool_runtime_mcp_path(): + with patch.object( + ToolManager, + "get_mcp_provider_controller", + return_value=SimpleNamespace(get_tool=lambda _: "mcp-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.MCP, + provider_id="mcp-1", + tool_name="m", + tenant_id="tenant-1", + ) + == "mcp-tool" + ) + + +def test_get_tool_runtime_app_not_implemented(): + with pytest.raises(NotImplementedError, match="app provider not implemented"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.APP, + provider_id="app", + tool_name="x", + tenant_id="tenant-1", + ) + + +def test_get_agent_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={"query": "hello"}, + credential_id=None, + ) + result = ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert result is tool_runtime + assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + + +def test_get_workflow_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + workflow_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_configurations={"query": "hello"}, + credential_id=None, + ) + tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + workflow_result = ToolManager.get_workflow_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + workflow_tool=workflow_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert workflow_result is tool_runtime2 + assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + + +def test_get_agent_runtime_raises_when_runtime_missing(): + tool_runtime = SimpleNamespace(runtime=None, get_merged_runtime_parameters=lambda: []) + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={}, + credential_id=None, + ) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={}): + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=Mock()): + with pytest.raises(ValueError, match="runtime not found"): + ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + ) + + +def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): + form_param = ToolParameter.get_simple_instance( + name="q", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + form_param.form = ToolParameter.ToolParameterForm.FORM + llm_param = ToolParameter.get_simple_instance( + name="llm", + llm_description="llm", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + llm_param.form = ToolParameter.ToolParameterForm.LLM + + tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + result = ToolManager.get_tool_runtime_from_plugin( + tool_type=ToolProviderType.API, + tenant_id="tenant-1", + provider="api-1", + tool_name="search", + tool_parameters={"q": "hello", "llm": "ignore"}, + ) + + assert result is tool_entity + assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + + +def test_hardcoded_provider_icon_success(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=True): + with patch("core.tools.tool_manager.mimetypes.guess_type", return_value=("image/svg+xml", None)): + icon_path, mime = ToolManager.get_hardcoded_provider_icon("time") + assert icon_path.endswith("icon.svg") + assert mime == "image/svg+xml" + + +def test_hardcoded_provider_icon_missing_raises(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=False): + with pytest.raises(ToolProviderNotFoundError, match="icon not found"): + ToolManager.get_hardcoded_provider_icon("time") + + +def test_list_hardcoded_providers_cache_hit(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + assert list(ToolManager.list_hardcoded_providers()) == list(ToolManager._hardcoded_providers.values()) + + +def test_clear_hardcoded_providers_cache_resets(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + ToolManager.clear_hardcoded_providers_cache() + assert ToolManager._hardcoded_providers == {} + assert ToolManager._builtin_providers_loaded is False + + +def test_list_hardcoded_providers_internal_loader(): + good_provider = SimpleNamespace( + entity=SimpleNamespace(identity=SimpleNamespace(name="good")), + get_tools=lambda: [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="tool-a", label="A")))], + ) + provider_class = Mock(return_value=good_provider) + + with patch("core.tools.tool_manager.listdir", return_value=["good", "bad", "__skip"]): + with patch("core.tools.tool_manager.path.isdir", side_effect=lambda p: "good" in p or "bad" in p): + with patch( + "core.tools.tool_manager.load_single_subclass_from_source", + side_effect=[provider_class, RuntimeError("boom")], + ): + ToolManager._hardcoded_providers = {} + ToolManager._builtin_tools_labels = {} + providers = list(ToolManager._list_hardcoded_providers()) + + assert providers == [good_provider] + assert ToolManager._hardcoded_providers["good"] is good_provider + assert ToolManager._builtin_tools_labels["tool-a"] == "A" + assert ToolManager._builtin_providers_loaded is True + + +def test_get_tool_label_loads_cache_and_handles_missing(): + ToolManager._builtin_tools_labels = {} + + def _load(): + ToolManager._builtin_tools_labels["tool-a"] = "Label A" + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load): + assert ToolManager.get_tool_label("tool-a") == "Label A" + assert ToolManager.get_tool_label("missing") is None + + +def test_list_default_builtin_providers_for_postgres_and_mysql(): + provider_records = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + + for scheme in ("postgresql", "mysql"): + session = Mock() + session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + session.query.return_value.where.return_value.all.return_value = provider_records + + with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + providers = ToolManager.list_default_builtin_providers("tenant-1") + + assert providers == provider_records + + +def test_list_providers_from_api_covers_builtin_api_workflow_and_mcp(monkeypatch): + hardcoded_controller = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="hardcoded"))) + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(identity=SimpleNamespace(name="plugin-provider")) + + api_db_provider_good = SimpleNamespace(id="api-1") + api_db_provider_bad = SimpleNamespace(id="api-2") + api_controller = SimpleNamespace(provider_id="api-1") + + workflow_db_provider_good = SimpleNamespace(id="wf-1") + workflow_db_provider_bad = SimpleNamespace(id="wf-2") + workflow_controller = SimpleNamespace(provider_id="wf-1") + + session = Mock() + session.scalars.side_effect = [ + SimpleNamespace(all=lambda: [api_db_provider_good, api_db_provider_bad]), + SimpleNamespace(all=lambda: [workflow_db_provider_good, workflow_db_provider_bad]), + ] + + _setup_list_providers_from_api_mocks( + monkeypatch, + session=session, + hardcoded_controller=hardcoded_controller, + plugin_controller=plugin_controller, + api_controller=api_controller, + workflow_controller=workflow_controller, + ) + providers = ToolManager.list_providers_from_api(user_id="user-1", tenant_id="tenant-1", typ="") + + names = {provider.name for provider in providers} + assert {"hardcoded", "plugin-provider", "api-provider", "workflow-provider", "mcp-provider"} <= names + + +def test_get_api_provider_controller_returns_controller_and_credentials(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch( + "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller + ) as mock_from_db: + built_controller, credentials = ToolManager.get_api_provider_controller("tenant-1", "api-1") + + assert built_controller is controller + assert credentials == provider.credentials + mock_from_db.assert_called_with(provider, ApiProviderAuthType.API_KEY_QUERY) + controller.load_bundled_tools.assert_called_once_with(provider.tools) + + +def test_user_get_api_provider_masks_credentials_and_adds_labels(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key_value": "secret"} + encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + with patch("core.tools.tool_manager.ToolLabelManager.get_tool_labels", return_value=["search"]): + user_payload = ToolManager.user_get_api_provider("api-provider", "tenant-1") + + assert user_payload["credentials"]["api_key_value"] == "***" + assert user_payload["labels"] == ["search"] + + +def test_get_api_provider_controller_not_found_raises(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): + ToolManager.get_api_provider_controller("tenant-1", "missing") + + +def test_get_mcp_provider_controller_returns_controller(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + controller = Mock() + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider.return_value = provider_entity + with patch("core.tools.tool_manager.MCPToolProviderController.from_db", return_value=controller): + built = ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + assert built is controller + + +def test_generate_mcp_tool_icon_url_returns_provider_icon(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider_entity.return_value = provider_entity + assert ToolManager.generate_mcp_tool_icon_url("tenant-1", "mcp-1") == provider_entity.provider_icon + + +def test_get_mcp_provider_controller_missing_raises(): + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service_cls.return_value.get_provider.side_effect = ValueError("missing") + with pytest.raises(ToolProviderNotFoundError, match="mcp provider mcp-1 not found"): + ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + + +def test_generate_tool_icon_urls_for_builtin_and_plugin(): + with patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "https://console.example.com"): + builtin_url = ToolManager.generate_builtin_tool_icon_url("time") + plugin_url = ToolManager.generate_plugin_tool_icon_url("tenant-1", "icon.svg") + + assert builtin_url.endswith("/tool-provider/builtin/time/icon") + assert "/plugin/icon" in plugin_url + + +def test_generate_tool_icon_urls_for_workflow_and_api(): + workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') + api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} + assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} + + +def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + + +def test_get_tool_icon_for_builtin_provider_variants(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", return_value="plugin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "plugin-provider") == "plugin-icon" + + with patch.object(ToolManager, "get_builtin_provider", return_value=SimpleNamespace()): + with patch.object(ToolManager, "generate_builtin_tool_icon_url", return_value="builtin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "time") == "builtin-icon" + + +def test_get_tool_icon_for_api_workflow_and_mcp(): + with patch.object(ToolManager, "generate_api_tool_icon_url", return_value={"background": "#000"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.API, "api-1") == {"background": "#000"} + + with patch.object(ToolManager, "generate_workflow_tool_icon_url", return_value={"background": "#111"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.WORKFLOW, "wf-1") == {"background": "#111"} + + with patch.object(ToolManager, "generate_mcp_tool_icon_url", return_value={"background": "#222"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.MCP, "mcp-1") == {"background": "#222"} + + +def test_get_tool_icon_plugin_error_returns_default(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", side_effect=RuntimeError("fail")): + icon = ToolManager.get_tool_icon("tenant-1", ToolProviderType.PLUGIN, "plugin-provider") + assert icon["background"] == "#252525" + + +def test_get_tool_icon_invalid_provider_type_raises(): + with pytest.raises(ValueError, match="provider type"): + ToolManager.get_tool_icon("tenant-1", "invalid", "x") # type: ignore[arg-type] + + +def test_convert_tool_parameters_type_agent_and_workflow_branches(): + file_param = ToolParameter.get_simple_instance( + name="file", + llm_description="file", + typ=ToolParameter.ToolParameterType.FILE, + required=True, + ) + file_param.form = ToolParameter.ToolParameterForm.FORM + + with pytest.raises(ValueError, match="file type parameter file not supported in agent"): + ToolManager._convert_tool_parameters_type( + parameters=[file_param], + variable_pool=None, + tool_configurations={"file": "x"}, + typ="agent", + ) + + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + plain = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=None, + tool_configurations={"text": "hello"}, + typ="workflow", + ) + assert plain == {"text": "hello"} + + variable_pool = Mock() + variable_pool.get.return_value = SimpleNamespace(value="from-variable") + variable_pool.convert_template.return_value = SimpleNamespace(text="from-template") + + mixed = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "mixed", "value": "Hello {{name}}"}}, + typ="workflow", + ) + assert mixed == {"text": "from-template"} + + variable = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "variable", "value": ["sys", "query"]}}, + typ="workflow", + ) + assert variable == {"text": "from-variable"} + + +def test_convert_tool_parameters_type_constant_branch(): + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + variable_pool = Mock() + + constant = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "constant", "value": "fixed"}}, + typ="workflow", + ) + + assert constant == {"text": "fixed"} diff --git a/api/tests/unit_tests/core/tools/test_tool_provider_controller.py b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py new file mode 100644 index 0000000000..30b8494c92 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest + +from core.entities.provider_entities import ProviderConfig +from core.tools.__base.tool import Tool +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError + + +class _DummyTool(Tool): + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _DummyController(ToolProviderController): + def get_tool(self, tool_name: str) -> Tool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name=tool_name, + label=I18nObject(en_US=tool_name), + provider="provider", + ), + parameters=[], + ) + return _DummyTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant")) + + +def _provider_identity() -> ToolProviderIdentity: + return ToolProviderIdentity( + author="author", + name="provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ) + + +def test_tool_provider_controller_get_credentials_schema_returns_deep_copy(): + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="api_key", required=False)], + ) + controller = _DummyController(entity=entity) + + schema = controller.get_credentials_schema() + schema[0].name = "changed" + + assert controller.entity.credentials_schema[0].name == "api_key" + + +def test_tool_provider_controller_default_provider_type(): + entity = ToolProviderEntity(identity=_provider_identity(), credentials_schema=[]) + controller = _DummyController(entity=entity) + + assert controller.provider_type == ToolProviderType.BUILT_IN + + +def test_validate_credentials_format_covers_required_default_and_type_rules(): + select_options = [ProviderConfig.Option(value="opt-a", label=I18nObject(en_US="A"))] + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="required_text", required=True), + ProviderConfig(type=ProviderConfig.Type.SECRET_INPUT, name="secret", required=False), + ProviderConfig(type=ProviderConfig.Type.SELECT, name="choice", required=False, options=select_options), + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="with_default", required=False, default="x"), + ], + ) + controller = _DummyController(entity=entity) + + credentials = {"required_text": "value", "secret": None, "choice": "opt-a"} + controller.validate_credentials_format(credentials) + assert credentials["with_default"] == "x" + + with pytest.raises(ToolProviderCredentialValidationError, match="not found"): + controller.validate_credentials_format({"required_text": "value", "unknown": "v"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="is required"): + controller.validate_credentials_format({"secret": "s"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="should be string"): + controller.validate_credentials_format({"required_text": 123}) # type: ignore[arg-type] + + with pytest.raises(ToolProviderCredentialValidationError, match="should be one of"): + controller.validate_credentials_format({"required_text": "value", "choice": "opt-b"}) diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py new file mode 100644 index 0000000000..4d59affb99 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.utils.configuration import ToolParameterConfigurationManager + + +class _DummyTool(Tool): + runtime_overrides: list[ToolParameter] + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, runtime_overrides: list[ToolParameter]): + super().__init__(entity=entity, runtime=runtime) + self.runtime_overrides = runtime_overrides + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + def get_runtime_parameters( + self, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> list[ToolParameter]: + return self.runtime_overrides + + +def _param( + name: str, + *, + typ: ToolParameter.ToolParameterType, + form: ToolParameter.ToolParameterForm, + required: bool = False, +) -> ToolParameter: + return ToolParameter( + name=name, + label=I18nObject(en_US=name), + placeholder=I18nObject(en_US=""), + human_description=I18nObject(en_US=""), + type=typ, + form=form, + required=required, + default=None, + ) + + +def _build_manager() -> ToolParameterConfigurationManager: + base_params = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("plain", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + runtime_overrides = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("runtime_only", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + entity = ToolEntity( + identity=ToolIdentity(author="a", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=base_params, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + tool = _DummyTool(entity=entity, runtime=runtime, runtime_overrides=runtime_overrides) + return ToolParameterConfigurationManager( + tenant_id="tenant-1", + tool_runtime=tool, + provider_name="provider-a", + provider_type=ToolProviderType.BUILT_IN, + identity_id="ID.1", + ) + + +def test_merge_and_mask_parameters(): + manager = _build_manager() + + masked = manager.mask_tool_parameters({"secret": "abcdefghi", "plain": "x", "runtime_only": "y"}) + assert masked["secret"] == "ab*****hi" + assert masked["plain"] == "x" + assert masked["runtime_only"] == "y" + + +def test_encrypt_tool_parameters(): + manager = _build_manager() + + with patch("core.tools.utils.configuration.encrypter.encrypt_token", return_value="enc"): + encrypted = manager.encrypt_tool_parameters({"secret": "raw", "plain": "x"}) + + assert encrypted["secret"] == "enc" + assert encrypted["plain"] == "x" + + +def test_decrypt_tool_parameters_cache_hit_and_miss(): + manager = _build_manager() + + with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + cache = cache_cls.return_value + cache.get.return_value = {"secret": "cached"} + assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"} + cache.set.assert_not_called() + + with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + cache = cache_cls.return_value + cache.get.return_value = None + with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"): + decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) + + assert decrypted["secret"] == "dec" + cache.set.assert_called_once() + + +def test_delete_tool_parameters_cache(): + manager = _build_manager() + + with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + manager.delete_tool_parameters_cache() + + cache_cls.return_value.delete.assert_called_once() + + +def test_configuration_manager_decrypt_suppresses_errors(): + manager = _build_manager() + with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + cache = cache_cls.return_value + cache.get.return_value = None + with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): + decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) + # decryption failure is suppressed, original value is retained. + assert decrypted["secret"] == "enc" diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py index 94be0bb573..ce77473dbd 100644 --- a/api/tests/unit_tests/core/tools/utils/test_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -1,10 +1,13 @@ import copy -from unittest.mock import patch +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch import pytest from core.entities.provider_entities import BasicProviderConfig from core.helper.provider_encryption import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter # --------------------------- @@ -13,13 +16,13 @@ from core.helper.provider_encryption import ProviderConfigEncrypter class NoopCache: """Simple cache stub: always returns None, does nothing for set/delete.""" - def get(self): + def get(self) -> Any | None: return None - def set(self, config): + def set(self, config: Any) -> None: pass - def delete(self): + def delete(self) -> None: pass @@ -179,3 +182,35 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj): out = encrypter_obj.decrypt({"password": "ENC_ERR"}) assert out["password"] == "ENC_ERR" + + +def test_create_tool_provider_encrypter_builds_cache_and_encrypter(): + basic_config = BasicProviderConfig(name="key", type=BasicProviderConfig.Type.TEXT_INPUT) + credential_schema_item = SimpleNamespace(to_basic_provider_config=lambda: basic_config) + controller = SimpleNamespace( + provider_type=SimpleNamespace(value="builtin"), + entity=SimpleNamespace(identity=SimpleNamespace(name="provider-a")), + get_credentials_schema=lambda: [credential_schema_item], + ) + + cache_instance = Mock() + encrypter_instance = Mock() + + with patch( + "core.tools.utils.encryption.SingletonProviderCredentialsCache", return_value=cache_instance + ) as cache_cls: + with patch("core.tools.utils.encryption.ProviderConfigEncrypter", return_value=encrypter_instance) as enc_cls: + encrypter, cache = create_tool_provider_encrypter("tenant-1", controller) + + assert encrypter is encrypter_instance + assert cache is cache_instance + cache_cls.assert_called_once_with( + tenant_id="tenant-1", + provider_type="builtin", + provider_identity="provider-a", + ) + enc_cls.assert_called_once_with( + tenant_id="tenant-1", + config=[basic_config], + provider_config_cache=cache_instance, + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py new file mode 100644 index 0000000000..4ce73272bf --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import uuid +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +from yaml import YAMLError + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.models.document import Document as RagDocument +from core.tools.utils.dataset_retriever import dataset_multi_retriever_tool as multi_retriever_module +from core.tools.utils.dataset_retriever import dataset_retriever_tool as single_retriever_module +from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool as SingleDatasetRetrieverTool +from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.tools.utils.uuid_utils import is_valid_uuid +from core.tools.utils.yaml_utils import _load_yaml_file, load_yaml_file_cached + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE) + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None, **_kwargs): + self._target = target + self._kwargs = kwargs or {} + + def start(self): + if self._target is not None: + self._target(**self._kwargs) + + def join(self): + return None + + +class _TestHitCallback(DatasetIndexToolCallbackHandler): + def __init__(self): + self.queries: list[tuple[str, str]] = [] + self.documents: list[RagDocument] | None = None + self.resources = None + + def on_query(self, query: str, dataset_id: str): + self.queries.append((query, dataset_id)) + + def on_tool_end(self, documents: list[RagDocument]): + self.documents = documents + + def return_retriever_resource_info(self, resource): + self.resources = list(resource) + + +def test_remove_leading_symbols_preserves_markdown_link_and_strips_punctuation(): + markdown = "[Example](https://example.com) content" + assert remove_leading_symbols(markdown) == markdown + + assert remove_leading_symbols("...Hello world") == "Hello world" + + +def test_is_valid_uuid_handles_valid_invalid_and_empty_values(): + assert is_valid_uuid(str(uuid.uuid4())) is True + assert is_valid_uuid("not-a-uuid") is False + assert is_valid_uuid("") is False + assert is_valid_uuid(None) is False + + +def test_load_yaml_file_valid(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + loaded = _load_yaml_file(file_path=str(valid_file)) + + assert loaded == {"a": 1, "b": "two"} + + +def test_load_yaml_file_missing(tmp_path): + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=str(tmp_path / "missing.yaml")) + + +def test_load_yaml_file_invalid(tmp_path): + invalid_file = tmp_path / "invalid.yaml" + invalid_file.write_text("a: [1, 2\n", encoding="utf-8") + + with pytest.raises(YAMLError): + _load_yaml_file(file_path=str(invalid_file)) + + +def test_load_yaml_file_cached_hits(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + load_yaml_file_cached.cache_clear() + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + assert load_yaml_file_cached.cache_info().hits == 1 + + +def test_single_dataset_retriever_from_dataset_builds_name_and_description(): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="Knowledge", description=None) + + tool = SingleDatasetRetrieverTool.from_dataset( + dataset=dataset, + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + inputs={}, + ) + + assert tool.name == "dataset_dataset_1" + assert tool.description == "useful for when you want to answer queries about the Knowledge" + + +def test_single_dataset_retriever_external_run_returns_content_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="external", + indexing_technique="high_quality", + retrieval_model={}, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ( + {"dataset-1": ["doc-a"]}, + {"logical_operator": "and"}, + ) + db_session = Mock() + db_session.scalar.return_value = dataset + external_documents = [ + {"content": "first", "metadata": {"document_id": "doc-a"}, "score": 0.9, "title": "Doc A"}, + {"content": "second", "metadata": {"document_id": "doc-b"}, "score": 0.8, "title": "Doc B"}, + ] + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={"x": 1}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object( + single_retriever_module.ExternalDatasetService, + "fetch_external_knowledge_retrieval", + return_value=external_documents, + ) as fetch_mock: + result = tool.run(query="hello") + + assert result == "first\nsecond" + assert callback.queries == [("hello", "dataset-1")] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].dataset_id == "dataset-1" + fetch_mock.assert_called_once() + + +def test_single_dataset_retriever_returns_empty_when_metadata_filter_finds_no_documents(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model=None, + ) + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ({"dataset-1": []}, {"logical_operator": "and"}) + db_session = Mock() + db_session.scalar.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + hit_callbacks=[_TestHitCallback()], + inputs={}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool.run(query="hello") + + assert result == "" + retrieve_mock.assert_not_called() + + +def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "score_threshold_enabled": True, + "score_threshold": 0.2, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "provider", "reranking_model_name": "model"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {"vector_weight": 0.6}}, + }, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = (None, None) + low_segment = SimpleNamespace( + id="seg-low", + dataset_id="dataset-1", + document_id="doc-low", + content="raw low", + answer="low answer", + hit_count=1, + word_count=10, + position=3, + index_node_hash="hash-low", + get_sign_content=lambda: "signed low", + ) + high_segment = SimpleNamespace( + id="seg-high", + dataset_id="dataset-1", + document_id="doc-high", + content="raw high", + answer=None, + hit_count=9, + word_count=25, + position=1, + index_node_hash="hash-high", + get_sign_content=lambda: "signed high", + ) + records = [ + SimpleNamespace(segment=low_segment, score=0.2, summary="summary low"), + SimpleNamespace(segment=high_segment, score=0.9, summary=None), + ] + documents = [ + RagDocument(page_content="first", metadata={"doc_id": "node-low", "score": 0.2}), + RagDocument(page_content="second", metadata={"doc_id": "node-high", "score": 0.9}), + ] + lookup_doc_low = SimpleNamespace( + id="doc-low", name="Document Low", data_source_type="upload_file", doc_metadata={"lang": "en"} + ) + lookup_doc_high = SimpleNamespace( + id="doc-high", name="Document High", data_source_type="notion", doc_metadata={"lang": "fr"} + ) + db_session = Mock() + db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] + db_session.query.return_value.filter_by.return_value.first.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={}, + top_k=2, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve", return_value=documents): + with patch.object( + single_retriever_module.RetrievalService, + "format_retrieval_documents", + return_value=records, + ): + result = tool.run(query="hello") + + assert result == "signed high\nsummary low\nquestion:signed low answer:low answer" + assert callback.documents == documents + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].segment_id == "seg-high" + assert resource_info[0].hit_count == 9 + assert resource_info[1].summary == "summary low" + assert resource_info[1].content == "question:raw low \nanswer:low answer" + + +def test_multi_dataset_retriever_from_dataset_sets_tool_name(): + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=["dataset-1"], + tenant_id="tenant-1", + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + assert tool.name == "dataset_tenant_1" + + +def test_multi_dataset_retriever_retriever_returns_early_when_dataset_is_missing(): + callback = _TestHitCallback() + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = None + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert result == [] + assert all_documents == [] + assert callback.queries == [] + retrieve_mock.assert_not_called() + + +def test_multi_dataset_retriever_retriever_non_economy_uses_retrieval_model(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 6, + "score_threshold_enabled": True, + "score_threshold": 0.4, + "reranking_enable": False, + "reranking_mode": None, + "weights": {"balanced": True}, + }, + ) + callback = _TestHitCallback() + documents = [RagDocument(page_content="retrieved", metadata={"doc_id": "node-1", "score": 0.4})] + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = dataset + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + top_k=2, + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve", return_value=documents) as retrieve_mock: + tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert all_documents == documents + assert callback.queries == [("hello", "dataset-1")] + retrieve_mock.assert_called_once_with( + retrieval_method="semantic_search", + dataset_id="dataset-1", + query="hello", + top_k=6, + score_threshold=0.4, + reranking_model=None, + reranking_mode="reranking_model", + weights={"balanced": True}, + ) + + +def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): + callback = _TestHitCallback() + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1", "dataset-2"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + top_k=2, + score_threshold=0.1, + ) + first_doc = RagDocument(page_content="first", metadata={"doc_id": "node-2", "score": 0.4}) + second_doc = RagDocument(page_content="second", metadata={"doc_id": "node-1", "score": 0.9}) + + def fake_retriever(**kwargs): + if kwargs["dataset_id"] == "dataset-1": + kwargs["all_documents"].append(first_doc) + else: + kwargs["all_documents"].append(second_doc) + + segment_for_node_2 = SimpleNamespace( + id="seg-2", + dataset_id="dataset-1", + document_id="doc-2", + index_node_id="node-2", + content="raw two", + answer="answer two", + hit_count=2, + word_count=20, + position=2, + index_node_hash="hash-2", + get_sign_content=lambda: "signed two", + ) + segment_for_node_1 = SimpleNamespace( + id="seg-1", + dataset_id="dataset-2", + document_id="doc-1", + index_node_id="node-1", + content="raw one", + answer=None, + hit_count=7, + word_count=30, + position=1, + index_node_hash="hash-1", + get_sign_content=lambda: "signed one", + ) + db_session = Mock() + db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] + db_session.query.return_value.filter_by.return_value.first.side_effect = [ + SimpleNamespace(id="dataset-2", name="Dataset Two"), + SimpleNamespace(id="dataset-1", name="Dataset One"), + ] + db_session.scalar.side_effect = [ + SimpleNamespace(id="doc-1", name="Doc One", data_source_type="upload_file", doc_metadata={"p": 1}), + SimpleNamespace(id="doc-2", name="Doc Two", data_source_type="notion", doc_metadata={"p": 2}), + ] + model_manager = Mock() + model_manager.get_model_instance.return_value = Mock() + rerank_runner = Mock() + rerank_runner.run.return_value = [second_doc, first_doc] + fake_current_app = SimpleNamespace(_get_current_object=lambda: _FakeFlaskApp()) + + with patch.object(tool, "_retriever", side_effect=fake_retriever) as retriever_mock: + with patch.object(multi_retriever_module, "current_app", fake_current_app): + with patch.object(multi_retriever_module.threading, "Thread", _ImmediateThread): + with patch.object(multi_retriever_module, "ModelManager", return_value=model_manager): + with patch.object(multi_retriever_module, "RerankModelRunner", return_value=rerank_runner): + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + result = tool.run(query="hello") + + assert result == "signed one\nquestion:signed two answer:answer two" + assert retriever_mock.call_count == 2 + assert callback.documents == [second_doc, first_doc] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].score == 0.9 + assert resource_info[0].content == "raw one" + assert resource_info[1].score == 0.4 + assert resource_info[1].content == "question:raw two \nanswer:answer two" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py new file mode 100644 index 0000000000..2acae889b2 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -0,0 +1,158 @@ +"""Unit tests for ModelInvocationUtils. + +Covers success and error branches for ModelInvocationUtils, including +InvokeModelError and invoke error mappings for InvokeAuthorizationError, +InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, and +InvokeServerUnavailableError. Assumes mocked model instances and managers. +""" + +from __future__ import annotations + +from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = ( + SimpleNamespace(model_properties=schema or {}) if schema is not None else None + ) + return SimpleNamespace( + provider="provider", + model="model-a", + model_name="model-a", + credentials={"api_key": "x"}, + model_type_instance=model_type_instance, + get_llm_num_tokens=lambda prompt_messages: 5, + invoke_llm=Mock(), + ) + + +@pytest.mark.parametrize( + ("model_instance", "expected", "error_match"), + [ + (None, None, "Model not found"), + (_mock_model_instance(schema=None), None, "No model schema found"), + (_mock_model_instance(schema={}), 2048, None), + (_mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 8192}), 8192, None), + ], + ids=[ + "missing-model", + "missing-schema", + "default-context-size", + "schema-context-size", + ], +) +def test_get_max_llm_context_tokens_branches(model_instance, expected, error_match): + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + if error_match: + with pytest.raises(InvokeModelError, match=error_match): + ModelInvocationUtils.get_max_llm_context_tokens("tenant") + else: + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + + +def test_calculate_tokens_handles_missing_model(): + manager = Mock() + manager.get_default_model_instance.return_value = None + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with pytest.raises(InvokeModelError, match="Model not found"): + ModelInvocationUtils.calculate_tokens("tenant", []) + + +def test_invoke_success_and_error_mappings(): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.return_value = SimpleNamespace( + message=SimpleNamespace(content="ok"), + usage=SimpleNamespace( + completion_tokens=7, + completion_unit_price=Decimal("0.1"), + completion_price_unit=Decimal(1), + latency=0.3, + total_price=Decimal("0.7"), + currency="USD", + ), + ) + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + response = ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) + + assert response.message.content == "ok" + assert db_mock.session.add.call_count == 1 + assert db_mock.session.commit.call_count == 2 + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (InvokeRateLimitError("rate"), "Invoke rate limit error"), + (InvokeBadRequestError("bad"), "Invoke bad request error"), + (InvokeConnectionError("conn"), "Invoke connection error"), + (InvokeAuthorizationError("auth"), "Invoke authorization error"), + (InvokeServerUnavailableError("down"), "Invoke server unavailable error"), + (RuntimeError("oops"), "Invoke error"), + ], + ids=[ + "rate-limit", + "bad-request", + "connection", + "authorization", + "server-unavailable", + "generic-error", + ], +) +def test_invoke_error_mappings(exc, expected): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.side_effect = exc + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + with pytest.raises(InvokeModelError, match=expected): + ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index f39158aa59..40f91b12a0 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,6 +1,12 @@ +from json.decoder import JSONDecodeError +from unittest.mock import Mock, patch + import pytest from flask import Flask +from yaml import YAMLError +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError from core.tools.utils.parser import ApiBasedToolSchemaParser @@ -189,3 +195,225 @@ def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): available_param = params_by_name["available"] assert available_param.type == "boolean" assert available_param.default is True + + +def test_sanitize_default_value_and_type_detection(): + assert ApiBasedToolSchemaParser._sanitize_default_value([]) is None + assert ApiBasedToolSchemaParser._sanitize_default_value({}) is None + assert ApiBasedToolSchemaParser._sanitize_default_value("ok") == "ok" + + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"}) == ToolParameter.ToolParameterType.FILE + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "integer"}) == ToolParameter.ToolParameterType.NUMBER + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "boolean"}}) + == ToolParameter.ToolParameterType.BOOLEAN + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}}) + == ToolParameter.ToolParameterType.FILES + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}}) + == ToolParameter.ToolParameterType.ARRAY + ) + assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "object"}) is None + + +def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "API", "version": "1.0.0", "description": "API description"}, + "servers": [ + {"url": "https://dev.example.com", "env": "dev"}, + {"url": "https://prod.example.com", "env": "prod"}, + ], + "paths": { + "/items": { + "post": { + "description": "Create item", + "parameters": [ + {"$ref": "#/components/parameters/token"}, + {"name": "token", "schema": {"type": "string"}}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ItemRequest"}}} + }, + } + } + }, + "components": { + "parameters": { + "token": {"name": "token", "required": True, "schema": {"type": "string"}}, + }, + "schemas": { + "ItemRequest": { + "type": "object", + "required": ["age"], + "properties": {"age": {"type": "integer", "description": "Age", "default": 18}}, + } + }, + }, + } + + extra_info: dict = {} + warning: dict = {} + with app.test_request_context(headers={"X-Request-Env": "prod"}): + bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + assert len(bundles) == 1 + assert bundles[0].server_url == "https://prod.example.com/items" + assert warning["duplicated_parameter"].startswith("Parameter token") + assert extra_info["description"] == "API description" + + +def test_parse_openapi_to_tool_bundle_no_server_raises(app): + openapi = {"info": {"title": "x"}, "servers": [], "paths": {}} + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="No server found"): + ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + +def test_parse_openapi_yaml_to_tool_bundle_invalid_yaml(app): + with app.test_request_context(): + with pytest.raises(ToolApiSchemaError, match="Invalid openapi yaml"): + ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle("null") + + +def test_parse_swagger_to_openapi_branches(): + with pytest.raises(ToolApiSchemaError, match="No server found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"info": {}, "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No paths found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"servers": [{"url": "https://x"}], "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No operationId found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"summary": "x", "responses": {}}}}, + } + ) + + warning: dict = {"seed": True} + converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"operationId": "getA", "responses": {}}}}, + "definitions": {"A": {"type": "object"}}, + }, + warning=warning, + ) + assert converted["openapi"] == "3.0.0" + assert converted["components"]["schemas"]["A"]["type"] == "object" + assert warning["missing_summary"].startswith("No summary or description found") + + +def test_parse_openai_plugin_json_branches(app): + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="Invalid openai plugin json"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle("{bad") + + with pytest.raises(ToolNotSupportedError, match="Only openapi is supported"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "graphql"}}' + ) + + +def test_parse_openai_plugin_json_http_branches(app): + with app.test_request_context(): + response = type("Resp", (), {"status_code": 500, "text": "", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=response): + with pytest.raises(ToolProviderNotFoundError, match="cannot get openapi yaml"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + response.close.assert_called_once() + + success_response = type("Resp", (), {"status_code": 200, "text": "openapi: 3.0.0", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=success_response): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle", + return_value=["bundle"], + ) as mock_parse: + bundles = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + assert bundles == ["bundle"] + mock_parse.assert_called_once() + success_response.close.assert_called_once() + + +def test_auto_parse_json_yaml_failure(): + with patch("core.tools.utils.parser.json_loads", side_effect=JSONDecodeError("bad", "x", 0)): + with patch("core.tools.utils.parser.safe_load", side_effect=YAMLError("bad yaml")): + with pytest.raises(ToolApiSchemaError, match="Invalid api schema, schema is neither json nor yaml"): + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(":::") + + +def test_auto_parse_openapi_success(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + return_value=["openapi-bundle"], + ): + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["openapi-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAPI + + +def test_auto_parse_openapi_then_swagger(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + loaded_content = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + converted_swagger = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=[ToolApiSchemaError("openapi error"), ["swagger-bundle"]], + ) as mock_parse_openapi: + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + return_value=converted_swagger, + ) as mock_parse_swagger: + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["swagger-bundle"] + assert schema_type == ApiProviderSchemaType.SWAGGER + mock_parse_swagger.assert_called_once_with(loaded_content, extra_info={}, warning={}) + assert mock_parse_openapi.call_count == 2 + mock_parse_openapi.assert_any_call(loaded_content, extra_info={}, warning={}) + mock_parse_openapi.assert_any_call(converted_swagger, extra_info={}, warning={}) + + +def test_auto_parse_openapi_swagger_then_plugin(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=ToolApiSchemaError("openapi error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + side_effect=ToolApiSchemaError("swagger error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle", + return_value=["plugin-bundle"], + ): + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["plugin-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAI_PLUGIN diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py new file mode 100644 index 0000000000..5691f33e65 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import pytest + +from core.tools.utils import system_oauth_encryption as oauth_encryption +from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter + + +def test_system_oauth_encrypter_roundtrip(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} + + encrypted = encrypter.encrypt_oauth_params(payload) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert encrypted + assert dict(decrypted) == payload + + +def test_system_oauth_encrypter_decrypt_validates_input(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(ValueError, match="must be a string"): + encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="cannot be empty"): + encrypter.decrypt_oauth_params("") + + +def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(OAuthEncryptionError, match="Decryption failed"): + encrypter.decrypt_oauth_params("not-base64") + + +def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") + + first = oauth_encryption.get_system_oauth_encrypter() + second = oauth_encryption.get_system_oauth_encrypter() + assert first is second + + encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) + assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + + +def test_create_system_oauth_encrypter_factory(): + encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemOAuthEncrypter) diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index c46e31d90f..dd79b79718 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,7 +1,9 @@ import pytest +from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): @@ -31,3 +33,91 @@ def test_ensure_no_human_input_nodes_raises_for_human_input(): WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + +def test_get_workflow_graph_variables_and_outputs(): + graph = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "variables": [ + { + "variable": "query", + "label": "Query", + "type": "text-input", + "required": True, + } + ], + }, + }, + { + "id": "end-1", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "string", "value_selector": ["n1", "answer"]}, + {"variable": "score", "value_type": "number", "value_selector": ["n1", "score"]}, + ], + }, + }, + { + "id": "end-2", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "object", "value_selector": ["n2", "answer"]}, + ], + }, + }, + ] + } + + variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) + assert len(variables) == 1 + assert variables[0].variable == "query" + assert variables[0].type == VariableEntityType.TEXT_INPUT + + outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph) + assert [output.variable for output in outputs] == ["answer", "score"] + assert outputs[0].value_type == "object" + assert outputs[1].value_type == "number" + + no_start = WorkflowToolConfigurationUtils.get_workflow_graph_variables({"nodes": []}) + assert no_start == [] + + +def test_check_is_synced_validation(): + variables = [ + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + required=True, + ) + ] + configs = [ + WorkflowToolParameterConfiguration( + name="query", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ] + + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=configs) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=[]) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced( + variables=variables, + tool_configurations=[ + WorkflowToolParameterConfiguration( + name="other", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ], + ) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py new file mode 100644 index 0000000000..dd140cbb27 --- /dev/null +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +def _controller() -> WorkflowToolProviderController: + entity = ToolProviderEntity( + identity=ToolProviderIdentity( + author="author", + name="wf-provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="WF"), + ), + credentials_schema=[], + ) + return WorkflowToolProviderController(entity=entity, provider_id="provider-1") + + +def _mock_session_with_begin() -> Mock: + session = Mock() + begin_cm = Mock() + begin_cm.__enter__ = Mock(return_value=None) + begin_cm.__exit__ = Mock(return_value=False) + session.begin.return_value = begin_cm + return session + + +def test_get_db_provider_tool_builds_entity(): + controller = _controller() + session = Mock() + workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) + session.query.return_value.where.return_value.first.return_value = workflow + app = SimpleNamespace(id="app-1") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + user_id="user-1", + parameter_configurations=[ + SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM), + SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM), + ], + ) + user = SimpleNamespace(name="Alice") + variables = [ + VariableEntity( + variable="country", + label="Country", + description="Country", + type=VariableEntityType.SELECT, + required=True, + options=["US", "IN"], + ) + ] + outputs = [ + SimpleNamespace(variable="json", value_type="string"), + SimpleNamespace(variable="answer", value_type="string"), + ] + + with ( + patch( + "core.tools.workflow_as_tool.provider.WorkflowAppConfigManager.convert_features", + return_value=SimpleNamespace(file_upload=True), + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_variables", + return_value=variables, + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_output", + return_value=outputs, + ), + ): + tool = controller._get_db_provider_tool(db_provider, app, session=session, user=user) + + assert tool.entity.identity.name == "workflow_tool" + # "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out. + assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}} + assert "json" not in tool.entity.output_schema["properties"] + assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT + assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES + assert controller.provider_type == ToolProviderType.WORKFLOW + + +def test_get_tool_returns_hit_or_none(): + controller = _controller() + tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool"))) + controller.tools = [tool] + + assert controller.get_tool("workflow_tool") is tool + assert controller.get_tool("missing") is None + + +def test_get_tools_returns_cached(): + controller = _controller() + cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))] + controller.tools = cached_tools # type: ignore[assignment] + + assert controller.get_tools("tenant-1") == cached_tools + + +def test_from_db_builds_controller(): + controller = _controller() + + app = SimpleNamespace(id="app-1") + user = SimpleNamespace(name="Alice") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.side_effect = [app, user] + fake_cm = MagicMock() + fake_cm.__enter__.return_value = session + fake_cm.__exit__.return_value = False + fake_session_factory = Mock() + fake_session_factory.create_session.return_value = fake_cm + + with patch("core.tools.workflow_as_tool.provider.session_factory", fake_session_factory): + with patch.object( + WorkflowToolProviderController, + "_get_db_provider_tool", + return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))), + ): + built = WorkflowToolProviderController.from_db(db_provider) + assert isinstance(built, WorkflowToolProviderController) + assert built.tools + + +def test_get_tools_returns_empty_when_provider_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = None + session_cls.return_value.__enter__.return_value = session + + assert controller.get_tools("tenant-1") == [] + + +def test_get_tools_raises_when_app_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.return_value = None + session_cls.return_value.__enter__.return_value = session + with pytest.raises(ValueError, match="app not found"): + controller.get_tools("tenant-1") diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 36fdb0218c..cc00f79698 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -1,20 +1,85 @@ +"""Unit tests for workflow-as-tool behavior. + +StubSession/StubScalars emulate SQLAlchemy session/scalars with minimal methods +(`scalar`, `scalars`, `expunge`, `commit`, `refresh`, context manager) to keep +database access mocked and predictable in tests. +""" + +import json from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.file import FILE_MODEL_IDENTITY -def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): - """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when - `WorkflowAppGenerator.generate` returns a result with `error` key inside - the `data` element. - """ +class StubScalars: + """Minimal stub for SQLAlchemy scalar results.""" + + _value: Any + + def __init__(self, value: Any) -> None: + self._value = value + + def first(self) -> Any: + return self._value + + +class StubSession: + """Minimal stub for session_factory-created sessions.""" + + scalar_results: list[Any] + scalars_results: list[Any] + expunge_calls: list[object] + + def __init__(self, *, scalar_results: list[Any] | None = None, scalars_results: list[Any] | None = None) -> None: + self.scalar_results = list(scalar_results or []) + self.scalars_results = list(scalars_results or []) + self.expunge_calls: list[object] = [] + + def scalar(self, _stmt: Any) -> Any: + return self.scalar_results.pop(0) + + def scalars(self, _stmt: Any) -> StubScalars: + return StubScalars(self.scalars_results.pop(0)) + + def expunge(self, value: Any) -> None: + self.expunge_calls.append(value) + + def begin(self) -> "StubSession": + return self + + def commit(self) -> None: + pass + + def refresh(self, _value: Any) -> None: + pass + + def close(self) -> None: + pass + + def __enter__(self) -> "StubSession": + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +def _build_tool() -> WorkflowTool: entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), parameters=[], @@ -22,9 +87,9 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel has_runtime_parameters=False, ) runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", + return WorkflowTool( + workflow_app_id="app-1", + workflow_as_tool_id="wf-tool-1", version="1", workflow_entities={}, workflow_call_depth=1, @@ -32,13 +97,19 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel runtime=runtime, ) + +def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): + """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when + `WorkflowAppGenerator.generate` returns a result with `error` key inside + the `data` element. + """ + tool = _build_tool() + # needs to patch those methods to avoid database access. monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -56,28 +127,12 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure pause_state_config is passed as None.""" + tool = _build_tool() monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) - from unittest.mock import MagicMock, Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -94,22 +149,7 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # Mock workflow outputs mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}} @@ -119,8 +159,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -134,10 +172,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Execute tool invocation messages = list(tool.invoke("test_user", {})) - # Verify generated messages - # Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages - assert len(messages) == 5 - # Verify variable messages variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE] assert len(variable_messages) == 3 @@ -151,7 +185,7 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Verify text message text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT] assert len(text_messages) == 1 - assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text + assert json.loads(text_messages[0].message.text) == mock_outputs # Verify JSON message json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON] @@ -161,30 +195,13 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should handle empty outputs correctly""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # needs to patch those methods to avoid database access. monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -217,61 +234,32 @@ def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPat assert json_messages[0].message.json_object == {} -def test_create_variable_message(): - """Test the functionality of creating variable messages""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) - - # Test different types of variable values - test_cases = [ +@pytest.mark.parametrize( + ("var_name", "var_value"), + [ ("string_var", "test string"), ("int_var", 42), ("float_var", 3.14), ("bool_var", True), ("list_var", [1, 2, 3]), ("dict_var", {"key": "value"}), - ] + ], +) +def test_create_variable_message(var_name, var_value): + """Create variable messages for multiple value types.""" + tool = _build_tool() - for var_name, var_value in test_cases: - message = tool.create_variable_message(var_name, var_value) + message = tool.create_variable_message(var_name, var_value) - assert message.type == ToolInvokeMessage.MessageType.VARIABLE - assert message.message.variable_name == var_name - assert message.message.variable_value == var_value - assert message.message.stream is False + assert message.type == ToolInvokeMessage.MessageType.VARIABLE + assert message.message.variable_name == var_name + assert message.message.variable_value == var_value + assert message.message.stream is False def test_create_file_message_should_include_file_marker(): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure file message includes marker and meta payload.""" + tool = _build_tool() file_obj = object() message = tool.create_file_message(file_obj) # type: ignore[arg-type] @@ -284,103 +272,247 @@ def test_create_file_message_should_include_file_marker(): def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch): """Ensure worker context can resolve EndUser when Account is missing.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - # SQLAlchemy Session APIs used by code under test - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - # support `with session_factory.create_session() as session:` - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") # Monkeypatch session factory to return our stub session + stub_session = StubSession(scalar_results=[tenant, None, end_user]) monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([tenant, None, end_user]), + lambda: stub_session, ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "tenant_id" resolved_user = tool._resolve_user_from_database(user_id=end_user.id) assert resolved_user is end_user + assert stub_session.expunge_calls == [end_user] def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch): """Return None if tenant cannot be found in worker context.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - # Monkeypatch session factory to return our stub session with no tenant monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([None]), + lambda: StubSession(scalar_results=[None]), ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "missing_tenant" resolved_user = tool._resolve_user_from_database(user_id="any") assert resolved_user is None + + +def test_workflow_tool_provider_type_and_fork_runtime(): + """Verify provider type and forked runtime behavior.""" + tool = _build_tool() + assert tool.tool_provider_type() == ToolProviderType.WORKFLOW + assert tool.latest_usage.total_tokens == 0 + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2", invoke_from=InvokeFrom.DEBUGGER)) + assert isinstance(forked, WorkflowTool) + assert forked.workflow_app_id == tool.workflow_app_id + assert forked.runtime.tenant_id == "tenant-2" + + +def test_derive_usage_from_top_level_usage_key(): + """Derive usage from top-level usage dict.""" + usage = WorkflowTool._derive_usage_from_result({"usage": {"total_tokens": 12, "total_price": "0.2"}}) + assert usage.total_tokens == 12 + + +def test_derive_usage_from_metadata_usage(): + """Derive usage from metadata usage dict.""" + metadata_usage = WorkflowTool._derive_usage_from_result({"metadata": {"usage": {"total_tokens": 7}}}) + assert metadata_usage.total_tokens == 7 + + +def test_derive_usage_from_totals(): + """Derive usage from top-level totals fields.""" + totals_usage = WorkflowTool._derive_usage_from_result( + {"total_tokens": "9", "total_price": "1.3", "currency": "USD"} + ) + assert totals_usage.total_tokens == 9 + assert str(totals_usage.total_price) == "1.3" + + +def test_derive_usage_from_empty(): + """Default usage values when result is empty.""" + empty_usage = WorkflowTool._derive_usage_from_result({}) + assert empty_usage.total_tokens == 0 + + +def test_extract_usage_from_nested(): + """Extract nested usage dict from result payloads.""" + nested = WorkflowTool._extract_usage_dict({"nested": [{"data": {"usage": {"total_tokens": 3}}}]}) + assert nested == {"total_tokens": 3} + + +def test_invoke_raises_when_user_not_found(monkeypatch: pytest.MonkeyPatch): + """Raise ToolInvokeError when user resolution fails.""" + tool = _build_tool() + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: None) + + with pytest.raises(ToolInvokeError, match="User not found"): + list(tool.invoke("missing", {})) + + +def test_resolve_user_from_database_returns_account(monkeypatch: pytest.MonkeyPatch): + """Resolve Account and set tenant in worker context.""" + tenant = SimpleNamespace(id="tenant_id") + account = SimpleNamespace(id="account_id", current_tenant=None) + session = StubSession(scalar_results=[tenant, account]) + + monkeypatch.setattr("core.tools.workflow_as_tool.tool.session_factory.create_session", lambda: session) + tool = _build_tool() + tool.runtime.tenant_id = "tenant_id" + + resolved = tool._resolve_user_from_database(user_id="account_id") + assert resolved is account + assert account.current_tenant is tenant + assert session.expunge_calls == [account] + + +def test_get_workflow_and_get_app_db_branches(monkeypatch: pytest.MonkeyPatch): + """Cover workflow/app retrieval branches and error cases.""" + tool = _build_tool() + latest_workflow = SimpleNamespace(id="wf-latest") + specific_workflow = SimpleNamespace(id="wf-v1") + app = SimpleNamespace(id="app-1") + sessions = iter( + [ + StubSession(scalar_results=[], scalars_results=[latest_workflow]), + StubSession(scalar_results=[specific_workflow], scalars_results=[]), + StubSession(scalar_results=[app], scalars_results=[]), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: next(sessions), + ) + + assert tool._get_workflow("app-1", "") is latest_workflow + assert tool._get_workflow("app-1", "1") is specific_workflow + assert tool._get_app("app-1") is app + + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession(scalar_results=[None, None], scalars_results=[None]), + ) + with pytest.raises(ValueError, match="workflow not found"): + tool._get_workflow("app-1", "1") + with pytest.raises(ValueError, match="app not found"): + tool._get_app("app-1") + + +def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: + """Build a WorkflowTool and stub merged runtime parameters for files/query.""" + tool = _build_tool() + files_param = ToolParameter.get_simple_instance( + name="files", + llm_description="files", + typ=ToolParameter.ToolParameterType.SYSTEM_FILES, + required=False, + ) + files_param.form = ToolParameter.ToolParameterForm.FORM + text_param = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + + monkeypatch.setattr(tool, "get_merged_runtime_parameters", lambda: [files_param, text_param]) + return tool + + +def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): + """Transform args into parameters and files payloads.""" + tool = _setup_transform_args_tool(monkeypatch) + + params, files = tool._transform_args( + { + "query": "hello", + "files": [ + { + "tenant_id": "tenant-1", + "type": "image", + "transfer_method": "tool_file", + "related_id": "tool-1", + "extension": ".png", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "local_file", + "related_id": "upload-1", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "remote_url", + "remote_url": "https://example.com/a.pdf", + }, + ], + } + ) + assert params == {"query": "hello"} + assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) + assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) + assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + + +def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): + """Ignore invalid file entries while keeping params.""" + tool = _setup_transform_args_tool(monkeypatch) + invalid_params, invalid_files = tool._transform_args({"query": "hello", "files": [{"invalid": True}]}) + assert invalid_params == {"query": "hello"} + assert invalid_files == [] + + +def test_extract_files(): + """Extract file outputs into result and file list.""" + tool = _build_tool() + built_files = [ + SimpleNamespace(id="file-1"), + SimpleNamespace(id="file-2"), + ] + with patch("core.tools.workflow_as_tool.tool.build_from_mapping", side_effect=built_files): + outputs = { + "attachments": [ + { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "tool_file", + "related_id": "r1", + } + ], + "single_file": { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "local_file", + "related_id": "r2", + }, + "text": "ok", + } + result, extracted_files = tool._extract_files(outputs) + + assert result["text"] == "ok" + assert len(extracted_files) == 2 + + +def test_update_file_mapping(): + """Map tool/local file transfer methods into output shape.""" + tool = _build_tool() + tool_file = tool._update_file_mapping({"transfer_method": "tool_file", "related_id": "tool-1"}) + assert tool_file["tool_file_id"] == "tool-1" + local_file = tool._update_file_mapping({"transfer_method": "local_file", "related_id": "upload-1"}) + assert local_file["upload_file_id"] == "upload-1" diff --git a/api/tests/unit_tests/tools/__init__.py b/api/tests/unit_tests/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2