test: Unit test cases for core.tools module (#32447)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
Co-authored-by: mahammadasim <135003320+mahammadasim@users.noreply.github.com>
This commit is contained in:
Rajat Agarwal 2026-03-12 09:18:13 +05:30 committed by GitHub
parent e99628b76f
commit b170eabaf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 5008 additions and 196 deletions

View File

@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
:param credential_type: the type of the credential, as CredentialType or str; str values
are normalized via CredentialType.of and may raise ValueError for invalid values.
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
empty list for CredentialType.UNAUTHORIZED or missing schemas.
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
Raises ValueError for invalid credential types.
"""
if credential_type == CredentialType.OAUTH2.value:
if isinstance(credential_type, str):
credential_type = CredentialType.of(credential_type)
if credential_type == CredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
if credential_type == CredentialType.UNAUTHORIZED:
return []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:

View File

@ -137,6 +137,7 @@ class ToolFileManager:
session.add(tool_file)
session.commit()
session.refresh(tool_file)
return tool_file

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from collections.abc import Generator
from types import SimpleNamespace
from typing import Any
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
class _BuiltinDummyTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("ok")
def _build_tool() -> _BuiltinDummyTool:
entity = ToolEntity(
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
parameters=[],
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
def test_builtin_tool_fork_and_provider_type():
tool = _build_tool()
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
assert isinstance(forked, _BuiltinDummyTool)
assert forked.runtime.tenant_id == "tenant-2"
assert tool.tool_provider_type() == ToolProviderType.BUILT_IN
def test_invoke_model_calls_model_invocation_utils_invoke():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
assert (
tool.invoke_model(
user_id="u1",
prompt_messages=[UserPromptMessage(content="hello")],
stop=[],
)
== "result"
)
mock_invoke.assert_called_once()
def test_get_max_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
assert tool.get_max_tokens() == 4096
def test_get_prompt_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
def test_runtime_none_raises():
tool = _build_tool()
tool.runtime = None
with pytest.raises(ValueError, match="runtime is required"):
tool.get_max_tokens()
with pytest.raises(ValueError, match="runtime is required"):
tool.get_prompt_tokens([UserPromptMessage(content="hello")])
def test_builtin_tool_summary_short_and_long_content_paths():
tool = _build_tool()
with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=100):
with patch.object(_BuiltinDummyTool, "get_prompt_tokens", return_value=10):
assert tool.summary(user_id="u1", content="short") == "short"
with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=10):
with patch.object(
_BuiltinDummyTool,
"get_prompt_tokens",
side_effect=lambda prompt_messages: len(prompt_messages[-1].content),
):
with patch.object(
_BuiltinDummyTool,
"invoke_model",
return_value=SimpleNamespace(message=SimpleNamespace(content="S")),
):
result = tool.summary(user_id="u1", content="x" * 30 + "\n" + "y" * 5)
assert result
assert "S" in result

View File

@ -0,0 +1,216 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderEntity, ToolProviderType
from core.tools.errors import ToolProviderNotFoundError
class _FakeBuiltinTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("ok")
class _ConcreteBuiltinProvider(BuiltinToolProviderController):
last_validation: tuple[str, dict[str, Any]] | None = None
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
self.last_validation = (user_id, credentials)
def _provider_yaml() -> dict[str, Any]:
return {
"identity": {
"author": "Dify",
"name": "fake_provider",
"label": {"en_US": "Fake Provider"},
"description": {"en_US": "Fake description"},
"icon": "icon.svg",
"tags": ["utilities"],
},
"credentials_for_provider": {
"api_key": {
"type": "secret-input",
"required": True,
}
},
"oauth_schema": {
"client_schema": [
{
"name": "client_id",
"type": "text-input",
}
],
"credentials_schema": [
{
"name": "access_token",
"type": "secret-input",
}
],
},
}
def _tool_yaml() -> dict[str, Any]:
return {
"identity": {
"author": "Dify",
"name": "tool_a",
"label": {"en_US": "Tool A"},
},
"parameters": [],
}
def test_builtin_tool_provider_init_load_tools_and_basic_accessors(monkeypatch):
yaml_payloads = [_provider_yaml(), _tool_yaml()]
def _load_yaml(*args, **kwargs):
return yaml_payloads.pop(0)
monkeypatch.setattr("core.tools.builtin_tool.provider.load_yaml_file_cached", _load_yaml)
monkeypatch.setattr(
"core.tools.builtin_tool.provider.listdir",
lambda *args, **kwargs: ["tool_a.yaml", "__init__.py", "readme.md"],
)
monkeypatch.setattr(
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
lambda *args, **kwargs: _FakeBuiltinTool,
)
provider = _ConcreteBuiltinProvider()
assert provider.get_credentials_schema()
assert provider.get_tools()
assert provider.get_tool("tool_a") is not None
assert provider.get_tool("missing") is None
assert provider.provider_type == ToolProviderType.BUILT_IN
assert provider.tool_labels == ["utilities"]
assert provider.need_credentials is True
oauth_schema = provider.get_credentials_schema_by_type(CredentialType.OAUTH2)
assert len(oauth_schema) == 1
api_schema = provider.get_credentials_schema_by_type(CredentialType.API_KEY)
assert len(api_schema) == 1
assert provider.get_oauth_client_schema()[0].name == "client_id"
assert set(provider.get_supported_credential_types()) == {CredentialType.API_KEY, CredentialType.OAUTH2}
def test_builtin_tool_provider_invalid_credential_type_raises():
with (
patch(
"core.tools.builtin_tool.provider.load_yaml_file_cached",
side_effect=[_provider_yaml(), _tool_yaml()],
),
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
patch(
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
return_value=_FakeBuiltinTool,
),
):
provider = _ConcreteBuiltinProvider()
with pytest.raises(ValueError, match="Invalid credential type: invalid"):
provider.get_credentials_schema_by_type("invalid")
def test_builtin_tool_provider_validate_credentials_delegates():
with (
patch(
"core.tools.builtin_tool.provider.load_yaml_file_cached",
side_effect=[_provider_yaml(), _tool_yaml()],
),
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
patch(
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
return_value=_FakeBuiltinTool,
),
):
provider = _ConcreteBuiltinProvider()
provider.validate_credentials("user-1", {"api_key": "secret"})
assert provider.last_validation == ("user-1", {"api_key": "secret"})
def test_builtin_tool_provider_unauthorized_schema_is_empty():
with (
patch(
"core.tools.builtin_tool.provider.load_yaml_file_cached",
side_effect=[_provider_yaml(), _tool_yaml()],
),
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
patch(
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
return_value=_FakeBuiltinTool,
),
):
provider = _ConcreteBuiltinProvider()
assert provider.get_credentials_schema_by_type(CredentialType.UNAUTHORIZED) == []
def test_builtin_tool_provider_init_raises_when_provider_yaml_missing():
with patch("core.tools.builtin_tool.provider.load_yaml_file_cached", side_effect=RuntimeError("boom")):
with pytest.raises(ToolProviderNotFoundError, match="can not load provider yaml"):
_ConcreteBuiltinProvider()
def test_builtin_tool_provider_handles_empty_credentials_and_oauth():
provider = object.__new__(_ConcreteBuiltinProvider)
provider.tools = []
provider.entity = ToolProviderEntity.model_validate(
{
"identity": {
"author": "Dify",
"name": "fake_provider",
"label": {"en_US": "Fake Provider"},
"description": {"en_US": "Fake description"},
"icon": "icon.svg",
"tags": None,
},
"credentials_schema": [],
"oauth_schema": None,
},
)
assert provider.get_oauth_client_schema() == []
assert provider.get_supported_credential_types() == []
assert provider.need_credentials is False
assert provider._get_tool_labels() == []
def test_builtin_tool_provider_forked_tool_runtime_is_initialized():
with (
patch(
"core.tools.builtin_tool.provider.load_yaml_file_cached",
side_effect=[_provider_yaml(), _tool_yaml()],
),
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
patch(
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
return_value=_FakeBuiltinTool,
),
):
provider = _ConcreteBuiltinProvider()
tool = provider.get_tool("tool_a")
assert tool is not None
assert isinstance(tool.runtime, ToolRuntime)
assert tool.runtime.tenant_id == ""
tool.runtime.invoke_from = InvokeFrom.DEBUGGER
assert tool.runtime.invoke_from == InvokeFrom.DEBUGGER

View File

@ -0,0 +1,310 @@
from __future__ import annotations
import math
from types import SimpleNamespace
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.providers.audio.audio import AudioToolProvider
from core.tools.builtin_tool.providers.audio.tools.asr import ASRTool
from core.tools.builtin_tool.providers.audio.tools.tts import TTSTool
from core.tools.builtin_tool.providers.code.code import CodeToolProvider
from core.tools.builtin_tool.providers.code.tools.simple_code import SimpleCode
from core.tools.builtin_tool.providers.time.time import WikiPediaProvider
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
from core.tools.builtin_tool.providers.time.tools.localtime_to_timestamp import LocaltimeToTimestampTool
from core.tools.builtin_tool.providers.time.tools.timestamp_to_localtime import TimestampToLocaltimeTool
from core.tools.builtin_tool.providers.time.tools.timezone_conversion import TimezoneConversionTool
from core.tools.builtin_tool.providers.time.tools.weekday import WeekdayTool
from core.tools.builtin_tool.providers.webscraper.tools.webscraper import WebscraperTool
from core.tools.builtin_tool.providers.webscraper.webscraper import WebscraperProvider
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from dify_graph.file.enums import FileType
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool:
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="tool-a",
label=I18nObject(en_US="tool-a"),
provider="provider-a",
),
parameters=[],
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
return tool_cls(provider="provider-a", entity=entity, runtime=runtime)
def _raise_runtime_error(*_args: object, **_kwargs: object) -> None:
raise RuntimeError("boom")
def test_current_time_tool():
current_tool = _build_builtin_tool(CurrentTimeTool)
utc_text = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "UTC"}))[0].message.text
assert utc_text
invalid_tz = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "Invalid/TZ"}))[0].message.text
assert "Invalid timezone" in invalid_tz
def test_localtime_to_timestamp_tool():
localtime_tool = _build_builtin_tool(LocaltimeToTimestampTool)
ts_message = list(
localtime_tool.invoke(user_id="u", tool_parameters={"localtime": "2024-01-01 10:00:00", "timezone": "UTC"})
)[0].message.text
ts_value = float(ts_message.strip())
assert math.isfinite(ts_value)
assert ts_value >= 0
with pytest.raises(ToolInvokeError):
LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC")
def test_timestamp_to_localtime_tool():
to_local_tool = _build_builtin_tool(TimestampToLocaltimeTool)
local_text = list(to_local_tool.invoke(user_id="u", tool_parameters={"timestamp": 1704067200, "timezone": "UTC"}))[
0
].message.text
assert "2024" in local_text
with pytest.raises(ToolInvokeError):
TimestampToLocaltimeTool.timestamp_to_localtime("bad", "UTC") # type: ignore[arg-type]
def test_timezone_conversion_tool():
timezone_tool = _build_builtin_tool(TimezoneConversionTool)
converted = list(
timezone_tool.invoke(
user_id="u",
tool_parameters={
"current_time": "2024-01-01 08:00:00",
"current_timezone": "UTC",
"target_timezone": "Asia/Tokyo",
},
)
)[0].message.text
assert converted.startswith("2024-01-01")
with pytest.raises(ToolInvokeError):
TimezoneConversionTool.timezone_convert("bad", "UTC", "Asia/Tokyo")
def test_weekday_tool():
weekday_tool = _build_builtin_tool(WeekdayTool)
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
assert "January 1, 2024" in valid
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
0
].message.text
assert "Invalid date" in invalid
with pytest.raises(ValueError, match="Month is required"):
list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "day": 1}))
def test_simple_code_valid_execution(monkeypatch):
simple_code = _build_builtin_tool(SimpleCode)
monkeypatch.setattr(
"core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code",
lambda *a: "ok",
)
result = list(
simple_code.invoke(
user_id="u",
tool_parameters={"language": "python3", "code": "print(1)"},
)
)[0].message.text
assert result == "ok"
def test_simple_code_invalid_language():
simple_code = _build_builtin_tool(SimpleCode)
with pytest.raises(ValueError, match="Only python3 and javascript"):
list(simple_code.invoke(user_id="u", tool_parameters={"language": "go", "code": "fmt.Println(1)"}))
def test_simple_code_execution_error(monkeypatch):
simple_code = _build_builtin_tool(SimpleCode)
monkeypatch.setattr(
"core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code",
_raise_runtime_error,
)
with pytest.raises(ToolInvokeError, match="boom"):
list(simple_code.invoke(user_id="u", tool_parameters={"language": "python3", "code": "print(1)"}))
def test_webscraper_empty_url():
webscraper = _build_builtin_tool(WebscraperTool)
empty = list(webscraper.invoke(user_id="u", tool_parameters={"url": ""}))[0].message.text
assert empty == "Please input url"
def test_webscraper_fetch(monkeypatch):
webscraper = _build_builtin_tool(WebscraperTool)
monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page")
full = list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))[0].message.text
assert full == "page"
def test_webscraper_summary(monkeypatch):
webscraper = _build_builtin_tool(WebscraperTool)
monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page")
monkeypatch.setattr(webscraper, "summary", lambda user_id, content: "summary")
summarized = list(
webscraper.invoke(
user_id="u",
tool_parameters={"url": "https://example.com", "generate_summary": True},
)
)[0].message.text
assert summarized == "summary"
def test_webscraper_fetch_error(monkeypatch):
webscraper = _build_builtin_tool(WebscraperTool)
monkeypatch.setattr(
"core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url",
_raise_runtime_error,
)
with pytest.raises(ToolInvokeError, match="boom"):
list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))
def test_asr_invalid_file():
asr = _build_builtin_tool(ASRTool)
file_obj = SimpleNamespace(type=FileType.DOCUMENT)
invalid_file = list(asr.invoke(user_id="u", tool_parameters={"audio_file": file_obj}))[0].message.text
assert "not a valid audio file" in invalid_file
def test_asr_valid_file_invocation(monkeypatch):
asr = _build_builtin_tool(ASRTool)
model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})()
model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})()
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes")
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager)
audio_file = SimpleNamespace(type=FileType.AUDIO)
ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text
assert ok == "transcript"
def test_asr_available_models_and_runtime_parameters(monkeypatch):
asr = _build_builtin_tool(ASRTool)
provider_model = type("PM", (), {"provider": "p", "models": [type("Model", (), {"model": "m"})()]})()
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.asr.ModelProviderService.get_models_by_model_type",
lambda *a, **k: [provider_model],
)
assert asr.get_available_models() == [("p", "m")]
assert asr.get_runtime_parameters()[0].name == "model"
def test_tts_invoke_returns_messages(monkeypatch):
tts = _build_builtin_tool(TTSTool)
voices_model_instance = type(
"TTSM",
(),
{
"get_tts_voices": lambda self: [{"value": "voice-1"}],
"invoke_tts": lambda self, **kwargs: [b"a", b"b"],
},
)()
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(),
)
messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB]
def test_tts_get_available_models_requires_runtime():
tts = _build_builtin_tool(TTSTool)
tts.runtime = None
with pytest.raises(ValueError, match="Runtime is required"):
tts.get_available_models()
def test_tts_tool_raises_when_runtime_missing():
tts = _build_builtin_tool(TTSTool)
tts.runtime = None
with pytest.raises(ValueError, match="Runtime is required"):
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
@pytest.mark.parametrize(
"voices",
[[{"value": None}], []],
)
def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices):
tts = _build_builtin_tool(TTSTool)
tts.runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
model_without_voice = type(
"TTSModelNoVoice",
(),
{
"get_tts_voices": lambda self: voices,
"invoke_tts": lambda self, **kwargs: [b"x"],
},
)()
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
)
with pytest.raises(ValueError, match="no voice available"):
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
def test_tts_tool_get_available_models_and_runtime_parameters(monkeypatch):
tts = _build_builtin_tool(TTSTool)
model_1 = SimpleNamespace(
model="model-a",
model_properties={ModelPropertyKey.VOICES: [{"mode": "v1", "name": "Voice 1"}]},
)
model_2 = SimpleNamespace(model="model-b", model_properties={})
provider_models = [SimpleNamespace(provider="provider-a", models=[model_1, model_2])]
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.tts.ModelProviderService.get_models_by_model_type",
lambda *args, **kwargs: provider_models,
)
available_models = tts.get_available_models()
assert available_models == [
("provider-a", "model-a", [{"mode": "v1", "name": "Voice 1"}]),
("provider-a", "model-b", []),
]
runtime_parameters = tts.get_runtime_parameters()
assert runtime_parameters[0].name == "model"
assert runtime_parameters[0].required is True
assert runtime_parameters[0].options[0].value == "provider-a#model-a"
assert runtime_parameters[1].name == "voice#provider-a#model-a"
def test_provider_classes_and_builtin_sort(monkeypatch):
# Use object.__new__ to avoid YAML-loading __init__; only pass-through validation is exercised.
# Ensure pass-through _validate_credentials methods are executed.
AudioToolProvider._validate_credentials(object.__new__(AudioToolProvider), "u", {})
CodeToolProvider._validate_credentials(object.__new__(CodeToolProvider), "u", {})
WikiPediaProvider._validate_credentials(object.__new__(WikiPediaProvider), "u", {})
WebscraperProvider._validate_credentials(object.__new__(WebscraperProvider), "u", {})
providers = [SimpleNamespace(name="b"), SimpleNamespace(name="a")]
monkeypatch.setattr(BuiltinToolProviderSort, "_position", {})
monkeypatch.setattr(
"core.tools.builtin_tool.providers._positions.get_tool_position_map",
lambda _: {"a": 0, "b": 1},
)
monkeypatch.setattr(
"core.tools.builtin_tool.providers._positions.sort_by_position_map",
lambda position, values, name_func: sorted(values, key=lambda x: name_func(x)),
)
sorted_providers = BuiltinToolProviderSort.sort(providers)
assert [p.name for p in sorted_providers] == ["a", "b"]

View File

@ -0,0 +1,285 @@
from __future__ import annotations
from types import SimpleNamespace
import httpx
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.tool import ApiTool, ParsedResponse
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
def _build_tool(*, openapi: dict | None = None) -> ApiTool:
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="tool-a",
label=I18nObject(en_US="tool-a"),
provider="provider-a",
),
parameters=[],
)
bundle = ApiToolBundle(
server_url="https://api.example.com/items/{id}",
method="GET",
summary="summary",
operation_id="op-id",
parameters=[],
author="author",
openapi=openapi or {"parameters": []},
)
runtime = ToolRuntime(
tenant_id="tenant-1",
invoke_from=InvokeFrom.DEBUGGER,
credentials={"auth_type": "api_key_header", "api_key_value": "k"},
)
return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="provider-id")
def test_parsed_response_to_string():
assert ParsedResponse({"a": 1}, True).to_string() == '{"a": 1}'
assert ParsedResponse("ok", False).to_string() == "ok"
def test_api_tool_fork_runtime_and_validate_credentials(monkeypatch):
tool = _build_tool()
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
assert isinstance(forked, ApiTool)
assert forked.runtime.tenant_id == "tenant-2"
tool.api_bundle = None # type: ignore[assignment]
with pytest.raises(ValueError, match="api_bundle is required"):
tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
tool = _build_tool()
assert tool.validate_credentials(credentials={}, parameters={}, format_only=True) == ""
monkeypatch.setattr(tool, "assembling_request", lambda parameters: {"Authorization": "Bearer x"})
monkeypatch.setattr(
tool,
"do_http_request",
lambda url, method, headers, parameters: httpx.Response(200, json={"ok": True}),
)
result = tool.validate_credentials(credentials={}, parameters={"a": 1}, format_only=False)
assert result == '{"ok": true}'
def test_assembling_request_auth_header_assembly():
tool = _build_tool()
headers = tool.assembling_request(parameters={})
assert headers["Authorization"] == "k"
tool.runtime.credentials = {
"auth_type": "api_key_header",
"api_key_header_prefix": "bearer",
"api_key_value": "abc",
}
headers = tool.assembling_request(parameters={})
assert headers["Authorization"] == "Bearer abc"
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_header_prefix": "basic", "api_key_value": "abc"}
headers = tool.assembling_request(parameters={})
assert headers["Authorization"] == "Basic abc"
tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_value": "abc"}
assert tool.assembling_request(parameters={}) == {}
def test_assembling_request_runtime_auth_errors():
tool = _build_tool()
tool.runtime = None
with pytest.raises(ToolProviderCredentialValidationError, match="runtime not initialized"):
tool.assembling_request(parameters={})
tool.runtime = ToolRuntime(tenant_id="tenant", credentials={})
with pytest.raises(ToolProviderCredentialValidationError, match="Missing auth_type"):
tool.assembling_request(parameters={})
tool.runtime.credentials = {"auth_type": "api_key_header"}
with pytest.raises(ToolProviderCredentialValidationError, match="Missing api_key_value"):
tool.assembling_request(parameters={})
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": 123}
with pytest.raises(ToolProviderCredentialValidationError, match="must be a string"):
tool.assembling_request(parameters={})
def test_assembling_request_parameter_validation_and_defaults():
tool = _build_tool()
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "x"}
tool.api_bundle.parameters = [
SimpleNamespace(required=True, name="required_param", default=None),
]
with pytest.raises(ToolParameterValidationError, match="Missing required parameter required_param"):
tool.assembling_request(parameters={})
tool.api_bundle.parameters = [
SimpleNamespace(required=True, name="required_param", default="d"),
]
params = {}
tool.assembling_request(parameters=params)
assert params["required_param"] == "d"
def test_validate_and_parse_response_branches():
tool = _build_tool()
with pytest.raises(ToolInvokeError, match="status code 500"):
tool.validate_and_parse_response(httpx.Response(500, text="boom"))
empty = tool.validate_and_parse_response(httpx.Response(200, content=b""))
assert empty.is_json is False
assert "Empty response from the tool" in str(empty.content)
json_resp = tool.validate_and_parse_response(
httpx.Response(200, json={"a": 1}, headers={"content-type": "application/json"})
)
assert json_resp.is_json is True
assert json_resp.content == {"a": 1}
non_json_type = tool.validate_and_parse_response(
httpx.Response(200, text='{"a": 1}', headers={"content-type": "text/plain"})
)
assert non_json_type.is_json is False
assert non_json_type.content == '{"a": 1}'
plain_resp = tool.validate_and_parse_response(httpx.Response(200, text="plain"))
assert plain_resp.is_json is False
assert plain_resp.content == "plain"
with pytest.raises(ValueError, match="Invalid response type"):
tool.validate_and_parse_response("invalid") # type: ignore[arg-type]
def test_get_parameter_value_and_type_conversion_helpers():
tool = _build_tool()
assert tool.get_parameter_value({"name": "x"}, {"x": 1}) == 1
assert tool.get_parameter_value({"name": "x", "required": False, "schema": {"default": "d"}}, {}) == "d"
with pytest.raises(ToolParameterValidationError, match="Missing required parameter x"):
tool.get_parameter_value({"name": "x", "required": True}, {})
assert tool._convert_body_property_any_of({}, "12", [{"type": "integer"}]) == 12
assert tool._convert_body_property_any_of({}, "1.5", [{"type": "number"}]) == 1.5
assert tool._convert_body_property_any_of({}, "true", [{"type": "boolean"}]) is True
assert tool._convert_body_property_any_of({}, "", [{"type": "null"}]) is None
assert tool._convert_body_property_any_of({}, "x", [{"anyOf": [{"type": "string"}]}]) == "x"
assert tool._convert_body_property_type({"type": "integer"}, "1") == 1
assert tool._convert_body_property_type({"type": "number"}, "1.2") == 1.2
assert tool._convert_body_property_type({"type": "string"}, 1) == "1"
assert tool._convert_body_property_type({"type": "boolean"}, 1) is True
assert tool._convert_body_property_type({"type": "null"}, None) is None
assert tool._convert_body_property_type({"type": "object"}, '{"a":1}') == {"a": 1}
assert tool._convert_body_property_type({"type": "array"}, "[1,2]") == [1, 2]
assert tool._convert_body_property_type({"type": "invalid"}, "v") == "v"
assert tool._convert_body_property_type({"anyOf": [{"type": "integer"}]}, "2") == 2
def test_do_http_request_builds_arguments_and_handles_invalid_method(monkeypatch):
openapi = {
"parameters": [
{"name": "id", "in": "path", "required": True, "schema": {"type": "string"}},
{"name": "q", "in": "query", "required": False, "schema": {"default": ""}},
{"name": "X-Extra", "in": "header", "required": False, "schema": {"default": "x"}},
{"name": "sid", "in": "cookie", "required": False, "schema": {"default": "cookie1"}},
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"required": ["count"],
"properties": {
"count": {"type": "integer"},
"name": {"type": "string", "default": "n"},
},
}
}
}
},
}
tool = _build_tool(openapi=openapi)
tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "v"}
headers = {}
captured = {}
def _fake_get(url, **kwargs):
captured["url"] = url
captured["kwargs"] = kwargs
return httpx.Response(200, text="ok")
monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.get", _fake_get)
response = tool.do_http_request(
"https://api.example.com/items/{id}",
"GET",
headers=headers,
parameters={"id": "123", "count": "2", "q": "search"},
)
assert isinstance(response, httpx.Response)
assert captured["url"].endswith("/items/123")
assert captured["kwargs"]["params"]["q"] == "search"
assert captured["kwargs"]["params"]["key"] == "v"
assert captured["kwargs"]["headers"]["Content-Type"] == "application/json"
invalid_method_tool = _build_tool(openapi={"parameters": []})
with pytest.raises(ValueError, match="Invalid http method"):
invalid_method_tool.do_http_request("https://api.example.com", "TRACE", headers={}, parameters={})
def test_do_http_request_handles_file_upload_and_invoke_paths(monkeypatch):
openapi = {
"parameters": [],
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"type": "object",
"properties": {"file": {"format": "binary"}},
}
}
}
},
}
tool = _build_tool(openapi=openapi)
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "k"}
fake_file = SimpleNamespace(filename="a.txt", mime_type="text/plain")
captured = {}
def _fake_post(url, **kwargs):
captured["headers"] = kwargs["headers"]
captured["files"] = kwargs["files"]
return httpx.Response(200, text="ok")
monkeypatch.setattr("core.tools.custom_tool.tool.download", lambda _: b"file-bytes")
monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.post", _fake_post)
response = tool.do_http_request(
"https://api.example.com/upload",
"POST",
headers={},
parameters={"file": fake_file},
)
assert isinstance(response, httpx.Response)
assert "Content-Type" not in captured["headers"]
assert captured["files"][0][0] == "file"
# _invoke JSON path
monkeypatch.setattr(tool, "assembling_request", lambda parameters: {})
monkeypatch.setattr(tool, "do_http_request", lambda *args, **kwargs: httpx.Response(200, text='{"a":1}'))
monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse({"a": 1}, True))
messages = list(tool.invoke(user_id="u1", tool_parameters={}))
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.TEXT]
# _invoke text path
monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse("plain", False))
messages = list(tool.invoke(user_id="u1", tool_parameters={}))
assert len(messages) == 1
assert messages[0].message.text == "plain"

View File

@ -0,0 +1,75 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolProviderType
def _db_provider() -> SimpleNamespace:
bundle = ApiToolBundle(
server_url="https://api.example.com/items",
method="GET",
summary="List items",
operation_id="list_items",
parameters=[],
author="author",
openapi={"parameters": []},
)
return SimpleNamespace(
id="provider-id",
tenant_id="tenant-1",
name="provider-a",
description="desc",
icon="icon.svg",
user=SimpleNamespace(name="Alice"),
tools=[bundle],
)
def test_api_tool_provider_from_db_and_parse_tool_bundle():
controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_HEADER)
assert controller.provider_type == ToolProviderType.API
assert any(c.name == "api_key_value" for c in controller.entity.credentials_schema)
tool = controller._parse_tool_bundle(_db_provider().tools[0])
assert isinstance(tool, ApiTool)
assert tool.entity.identity.provider == "provider-id"
def test_api_tool_provider_from_db_query_auth_and_none_auth():
query_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_QUERY)
assert any(c.name == "api_key_query_param" for c in query_controller.entity.credentials_schema)
none_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE)
assert [c.name for c in none_controller.entity.credentials_schema] == ["auth_type"]
def test_api_tool_provider_load_get_tools_and_get_tool():
controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE)
loaded = controller.load_bundled_tools(_db_provider().tools)
assert len(loaded) == 1
assert isinstance(controller.get_tool("list_items"), ApiTool)
with pytest.raises(ValueError, match="not found"):
controller.get_tool("missing")
# Return cached tools without querying database.
cached = controller.get_tools("tenant-1")
assert len(cached) == 1
# Force DB fetch branch.
controller.tools = []
provider_with_tools = _db_provider()
with patch("core.tools.custom_tool.provider.db") as mock_db:
scalars_result = Mock()
scalars_result.all.return_value = [provider_with_tools]
mock_db.session.scalars.return_value = scalars_result
tools = controller.get_tools("tenant-1")
assert len(tools) == 1

View File

@ -0,0 +1,145 @@
"""Unit tests for DatasetRetrieverTool behavior and retrieval wiring."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
def _retrieve_config() -> DatasetRetrieveConfigEntity:
return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE)
def test_get_dataset_tools_returns_empty_for_empty_dataset_ids() -> None:
# Arrange
retrieve_config = _retrieve_config()
# Act
tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id="tenant",
dataset_ids=[],
retrieve_config=retrieve_config,
return_resource=False,
invoke_from=InvokeFrom.DEBUGGER,
hit_callback=Mock(),
user_id="u",
inputs={},
)
# Assert
assert tools == []
def test_get_dataset_tools_returns_empty_for_missing_retrieve_config() -> None:
# Arrange
dataset_ids = ["d1"]
# Act
tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id="tenant",
dataset_ids=dataset_ids,
retrieve_config=None, # type: ignore[arg-type]
return_resource=False,
invoke_from=InvokeFrom.DEBUGGER,
hit_callback=Mock(),
user_id="u",
inputs={},
)
# Assert
assert tools == []
def test_get_dataset_tools_builds_tool_and_restores_strategy() -> None:
# Arrange
retrieve_config = _retrieve_config()
retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}")
feature = Mock()
feature.to_dataset_retriever_tool.return_value = [retrieval_tool]
# Act
with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature):
tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id="tenant",
dataset_ids=["d1"],
retrieve_config=retrieve_config,
return_resource=True,
invoke_from=InvokeFrom.DEBUGGER,
hit_callback=Mock(),
user_id="u",
inputs={"x": 1},
)
# Assert
assert len(tools) == 1
assert tools[0].entity.identity.name == "dataset_tool"
assert retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
def _build_dataset_tool() -> tuple[DatasetRetrieverTool, SimpleNamespace]:
retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}")
feature = Mock()
feature.to_dataset_retriever_tool.return_value = [retrieval_tool]
with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature):
tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id="tenant",
dataset_ids=["d1"],
retrieve_config=_retrieve_config(),
return_resource=False,
invoke_from=InvokeFrom.DEBUGGER,
hit_callback=Mock(),
user_id="u",
inputs={},
)
return tools[0], retrieval_tool
def test_runtime_parameters_shape() -> None:
# Arrange
tool, _ = _build_dataset_tool()
# Act
params = tool.get_runtime_parameters()
# Assert
assert len(params) == 1
assert params[0].name == "query"
def test_empty_query_behavior() -> None:
# Arrange
tool, _ = _build_dataset_tool()
# Act
empty_query = list(tool.invoke(user_id="u", tool_parameters={}))
# Assert
assert len(empty_query) == 1
assert empty_query[0].message.text == "please input query"
def test_query_invocation_result() -> None:
# Arrange
tool, _ = _build_dataset_tool()
# Act
result = list(tool.invoke(user_id="u", tool_parameters={"query": "hello"}))
# Assert
assert len(result) == 1
assert result[0].message.text == "result:hello"
def test_validate_credentials() -> None:
# Arrange
tool, _ = _build_dataset_tool()
# Act
result = tool.validate_credentials(credentials={}, parameters={}, format_only=False)
# Assert
assert result is None

View File

@ -0,0 +1,150 @@
from __future__ import annotations
import base64
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp.types import (
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
from core.tools.mcp_tool.tool import MCPTool
def _build_mcp_tool(*, with_output_schema: bool = True) -> MCPTool:
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="remote-tool",
label=I18nObject(en_US="remote-tool"),
provider="provider-id",
),
parameters=[],
output_schema={"type": "object"} if with_output_schema else {},
)
return MCPTool(
entity=entity,
runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER),
tenant_id="tenant-1",
icon="icon.svg",
server_url="https://mcp.example.com",
provider_id="provider-id",
headers={"x-auth": "token"},
)
def test_mcp_tool_provider_type_and_fork_runtime():
tool = _build_mcp_tool()
assert tool.tool_provider_type() == ToolProviderType.MCP
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
assert isinstance(forked, MCPTool)
assert forked.runtime.tenant_id == "tenant-2"
assert forked.provider_id == "provider-id"
def test_mcp_tool_text_and_json_processing_helpers():
tool = _build_mcp_tool()
json_messages = list(tool._process_text_content(TextContent(type="text", text='{"a": 1}')))
assert json_messages[0].type == ToolInvokeMessage.MessageType.JSON
plain_messages = list(tool._process_text_content(TextContent(type="text", text="not-json")))
assert plain_messages[0].type == ToolInvokeMessage.MessageType.TEXT
assert plain_messages[0].message.text == "not-json"
list_messages = list(tool._process_json_content([{"k": 1}, {"k": 2}]))
assert [m.type for m in list_messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.JSON]
mixed_list_messages = list(tool._process_json_list([{"k": 1}, 2]))
assert len(mixed_list_messages) == 1
assert mixed_list_messages[0].type == ToolInvokeMessage.MessageType.TEXT
primitive_messages = list(tool._process_json_content(123))
assert primitive_messages[0].message.text == "123"
def test_mcp_tool_usage_extraction_helpers():
usage = MCPTool._extract_usage_dict({"usage": {"total_tokens": 9}})
assert usage == {"total_tokens": 9}
usage = MCPTool._extract_usage_dict({"metadata": {"usage": {"prompt_tokens": 3, "completion_tokens": 2}}})
assert usage == {"prompt_tokens": 3, "completion_tokens": 2}
usage = MCPTool._extract_usage_dict({"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3})
assert usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
usage = MCPTool._extract_usage_dict({"nested": [{"deep": {"usage": {"total_tokens": 7}}}]})
assert usage == {"total_tokens": 7}
result_with_usage = CallToolResult(content=[], _meta={"usage": {"prompt_tokens": 1, "completion_tokens": 2}})
derived = MCPTool._derive_usage_from_result(result_with_usage)
assert derived.prompt_tokens == 1
assert derived.completion_tokens == 2
result_without_usage = CallToolResult(content=[], _meta=None)
derived = MCPTool._derive_usage_from_result(result_without_usage)
assert derived.total_tokens == 0
def test_mcp_tool_invoke_handles_content_types_and_structured_output():
tool = _build_mcp_tool()
img_data = base64.b64encode(b"img").decode()
blob_data = base64.b64encode(b"blob").decode()
result = CallToolResult(
content=[
TextContent(type="text", text='{"a": 1}'),
ImageContent(type="image", data=img_data, mimeType="image/png"),
EmbeddedResource(
type="resource",
resource=TextResourceContents(uri="file:///tmp/a.txt", text="embedded-text"),
),
EmbeddedResource(
type="resource",
resource=BlobResourceContents(
uri="file:///tmp/b.bin",
blob=blob_data,
mimeType="application/octet-stream",
),
),
],
structuredContent={"x": 1},
_meta={"usage": {"prompt_tokens": 2, "completion_tokens": 3}},
)
with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool.invoke(user_id="user-1", tool_parameters={"a": 1}))
types = [m.type for m in messages]
assert ToolInvokeMessage.MessageType.JSON in types
assert ToolInvokeMessage.MessageType.BLOB in types
assert ToolInvokeMessage.MessageType.TEXT in types
assert ToolInvokeMessage.MessageType.VARIABLE in types
assert tool.latest_usage.total_tokens == 5
def test_mcp_tool_invoke_raises_for_unsupported_embedded_resource():
tool = _build_mcp_tool()
# Use model_construct to bypass pydantic validation and force unsupported resource path.
bad_resource = EmbeddedResource.model_construct(type="resource", resource=object())
result = CallToolResult(content=[bad_resource], _meta=None)
with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result):
with pytest.raises(ToolInvokeError, match="Unsupported embedded resource type"):
list(tool.invoke(user_id="user-1", tool_parameters={}))
def test_mcp_tool_handle_none_parameter_filters_empty_values():
tool = _build_mcp_tool()
cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"})
assert cleaned == {"a": 1, "e": "ok"}

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from datetime import datetime
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
def _build_mcp_entity(*, icon: str = "icon.svg") -> MCPProviderEntity:
now = datetime.now()
return MCPProviderEntity(
id="db-id",
provider_id="provider-id",
name="MCP Provider",
tenant_id="tenant-1",
user_id="user-1",
server_url="https://mcp.example.com",
headers={},
timeout=30,
sse_read_timeout=300,
authed=False,
credentials={},
tools=[
{
"name": "remote-tool",
"description": "remote tool",
"inputSchema": {},
"outputSchema": {"type": "object"},
}
],
icon=icon,
created_at=now,
updated_at=now,
)
def test_mcp_tool_provider_controller_from_entity_and_get_tools():
entity = _build_mcp_entity()
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
controller = MCPToolProviderController.from_entity(entity)
assert controller.provider_type == ToolProviderType.MCP
tool = controller.get_tool("remote-tool")
assert isinstance(tool, MCPTool)
assert tool.tenant_id == "tenant-1"
tools = controller.get_tools()
assert len(tools) == 1
assert isinstance(tools[0], MCPTool)
with pytest.raises(ValueError, match="not found"):
controller.get_tool("missing")
def test_mcp_tool_provider_controller_from_entity_requires_icon():
entity = _build_mcp_entity(icon="")
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
with pytest.raises(ValueError, match="icon is required"):
MCPToolProviderController.from_entity(entity)
def test_mcp_tool_provider_controller_from_db_delegates_to_entity():
entity = _build_mcp_entity()
db_provider = Mock()
db_provider.to_entity.return_value = entity
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
controller = MCPToolProviderController.from_db(db_provider)
assert isinstance(controller, MCPToolProviderController)

View File

@ -0,0 +1,91 @@
from __future__ import annotations
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolParameter
from core.tools.plugin_tool.tool import PluginTool
def _build_plugin_tool(*, has_runtime_parameters: bool) -> PluginTool:
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="tool-a",
label=I18nObject(en_US="tool-a"),
provider="provider-a",
),
parameters=[
ToolParameter.get_simple_instance(
name="query",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
],
has_runtime_parameters=has_runtime_parameters,
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, credentials={"api_key": "x"})
return PluginTool(
entity=entity,
runtime=runtime,
tenant_id="tenant-1",
icon="icon.svg",
plugin_unique_identifier="plugin-uid",
)
def test_plugin_tool_invoke_and_fork_runtime():
tool = _build_plugin_tool(has_runtime_parameters=False)
manager = Mock()
manager.invoke.return_value = iter([tool.create_text_message("ok")])
with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager):
with patch(
"core.tools.plugin_tool.tool.convert_parameters_to_plugin_format",
return_value={"converted": 1},
):
messages = list(tool.invoke(user_id="user-1", tool_parameters={"raw": 1}))
assert [m.message.text for m in messages] == ["ok"]
manager.invoke.assert_called_once()
assert manager.invoke.call_args.kwargs["tool_parameters"] == {"converted": 1}
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
assert isinstance(forked, PluginTool)
assert forked.runtime.tenant_id == "tenant-2"
assert forked.plugin_unique_identifier == "plugin-uid"
def test_plugin_tool_get_runtime_parameters_branches():
tool = _build_plugin_tool(has_runtime_parameters=False)
assert tool.get_runtime_parameters() == tool.entity.parameters
tool = _build_plugin_tool(has_runtime_parameters=True)
cached = [
ToolParameter.get_simple_instance(
name="k",
llm_description="k",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
]
tool.runtime_parameters = cached
assert tool.get_runtime_parameters() == cached
tool.runtime_parameters = None
manager = Mock()
returned = [
ToolParameter.get_simple_instance(
name="dyn",
llm_description="dyn",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
]
manager.get_runtime_parameters.return_value = returned
with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager):
assert tool.get_runtime_parameters(conversation_id="c1", app_id="a1", message_id="m1") == returned
assert tool.runtime_parameters == returned

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from unittest.mock import Mock, patch
import pytest
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolProviderEntityWithPlugin,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
def _build_controller() -> PluginToolProviderController:
tool_entity = ToolEntity(
identity=ToolIdentity(
author="author",
name="tool-a",
label=I18nObject(en_US="tool-a"),
provider="provider-a",
),
parameters=[],
)
entity = ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author="author",
name="provider-a",
description=I18nObject(en_US="desc"),
icon="icon.svg",
label=I18nObject(en_US="Provider"),
),
credentials_schema=[],
plugin_id="plugin-id",
tools=[tool_entity],
)
return PluginToolProviderController(
entity=entity,
plugin_id="plugin-id",
plugin_unique_identifier="plugin-uid",
tenant_id="tenant-1",
)
def test_plugin_tool_provider_controller_basic_behaviors():
controller = _build_controller()
assert controller.provider_type == ToolProviderType.PLUGIN
tool = controller.get_tool("tool-a")
assert isinstance(tool, PluginTool)
assert tool.runtime.tenant_id == "tenant-1"
tools = controller.get_tools()
assert len(tools) == 1
assert isinstance(tools[0], PluginTool)
with pytest.raises(ValueError, match="not found"):
controller.get_tool("missing")
def test_validate_credentials_success():
controller = _build_controller()
manager = Mock()
manager.validate_provider_credentials.return_value = True
with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager):
controller._validate_credentials(user_id="u1", credentials={"api_key": "x"})
manager.validate_provider_credentials.assert_called_once_with(
tenant_id="tenant-1",
user_id="u1",
provider="provider-a",
credentials={"api_key": "x"},
)
def test_validate_credentials_failure():
controller = _build_controller()
manager = Mock()
manager.validate_provider_credentials.return_value = False
with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager):
with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"):
controller._validate_credentials(user_id="u1", credentials={"api_key": "x"})

View File

@ -0,0 +1,119 @@
"""Unit tests for core.tools.signature covering signing and verification invariants."""
from __future__ import annotations
from urllib.parse import parse_qs, urlparse
import pytest
from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature
def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x01" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120)
url = sign_tool_file("tool-file-id", ".png", for_external=False)
parsed = urlparse(url)
query = parse_qs(parsed.query)
timestamp = query["timestamp"][0]
nonce = query["nonce"][0]
sign = query["sign"][0]
assert parsed.scheme == "https"
assert parsed.netloc == "internal.example.com"
assert parsed.path == "/files/tools/tool-file-id.png"
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is True
def test_sign_tool_file_for_external_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x04" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120)
url = sign_tool_file("tool-file-id", ".png", for_external=True)
parsed = urlparse(url)
assert parsed.scheme == "https"
assert parsed.netloc == "files.example.com"
assert parsed.path == "/files/tools/tool-file-id.png"
def test_verify_tool_file_signature_rejects_invalid_sign(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10)
url = sign_tool_file("tool-file-id", ".txt")
parsed = urlparse(url)
query = parse_qs(parsed.query)
timestamp = query["timestamp"][0]
nonce = query["nonce"][0]
sign = query["sign"][0]
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, "bad-signature") is False
def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10)
url = sign_tool_file("tool-file-id", ".txt")
parsed = urlparse(url)
query = parse_qs(parsed.query)
timestamp = query["timestamp"][0]
nonce = query["nonce"][0]
sign = query["sign"][0]
monkeypatch.setattr("core.tools.signature.time.time", lambda: int(timestamp) + 99)
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False
def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
url = sign_upload_file("upload-id", ".png")
parsed = urlparse(url)
query = parse_qs(parsed.query)
assert parsed.netloc == "internal.example.com"
assert parsed.path == "/files/upload-id/image-preview"
assert query["timestamp"][0]
assert query["nonce"][0]
assert query["sign"][0]
def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
url = sign_upload_file("upload-id", ".png")
parsed = urlparse(url)
query = parse_qs(parsed.query)
assert parsed.netloc == "files.example.com"
assert parsed.path == "/files/upload-id/image-preview"
assert query["timestamp"][0]
assert query["nonce"][0]
assert query["sign"][0]

View File

@ -0,0 +1,280 @@
from __future__ import annotations
from collections.abc import Generator
from types import SimpleNamespace
from typing import Any
from unittest.mock import Mock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolInvokeMeta,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolParameterValidationError,
)
from core.tools.tool_engine import ToolEngine
class _DummyTool(Tool):
result: Any
raise_error: Exception | None
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
super().__init__(entity=entity, runtime=runtime)
self.result = [self.create_text_message("ok")]
self.raise_error = None
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
if self.raise_error:
raise self.raise_error
if isinstance(self.result, list | Generator):
yield from self.result
else:
yield self.result
def _build_tool(with_llm_parameter: bool = False) -> _DummyTool:
parameters = []
if with_llm_parameter:
parameters = [
ToolParameter.get_simple_instance(
name="query",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
]
entity = ToolEntity(
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
parameters=parameters,
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, runtime_parameters={"rt": 1})
return _DummyTool(entity=entity, runtime=runtime)
def test_convert_tool_response_to_str_and_extract_binary_messages():
tool = _build_tool()
messages = [
tool.create_text_message("hello"),
tool.create_link_message("https://example.com"),
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE,
message=ToolInvokeMessage.TextMessage(text="https://example.com/a.png"),
meta={"mime_type": "image/png"},
),
tool.create_json_message({"a": 1}),
tool.create_json_message({"a": 1}, suppress_output=True),
]
text = ToolEngine._convert_tool_response_to_str(messages)
assert "hello" in text
assert "result link: https://example.com." in text
assert '"a": 1' in text
blob_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.TextMessage(text="https://example.com/blob.bin"),
meta={"mime_type": "application/octet-stream"},
)
link_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text="https://example.com/file.pdf"),
meta={"mime_type": "application/pdf"},
)
binaries = list(ToolEngine._extract_tool_response_binary_and_text([messages[2], blob_message, link_message]))
assert [b.mimetype for b in binaries] == ["image/png", "application/octet-stream", "application/pdf"]
with pytest.raises(ValueError, match="missing meta data"):
list(
ToolEngine._extract_tool_response_binary_and_text(
[
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE,
message=ToolInvokeMessage.TextMessage(text="x"),
)
]
)
)
def test_create_message_files_and_invoke_generator():
binaries = [
ToolInvokeMessageBinary(mimetype="image/png", url="https://example.com/abc.png"),
ToolInvokeMessageBinary(mimetype="audio/wav", url="https://example.com/def.wav"),
]
created = []
def _message_file_factory(**kwargs):
obj = SimpleNamespace(id=f"mf-{len(created) + 1}", **kwargs)
created.append(obj)
return obj
with patch("core.tools.tool_engine.MessageFile", side_effect=_message_file_factory):
with patch("core.tools.tool_engine.db") as mock_db:
ids = ToolEngine._create_message_files(
tool_messages=binaries,
agent_message=SimpleNamespace(id="msg-1"),
invoke_from=InvokeFrom.DEBUGGER,
user_id="user-1",
)
assert ids == ["mf-1", "mf-2"]
assert mock_db.session.add.call_count == 2
mock_db.session.close.assert_called_once()
tool = _build_tool()
invoked = list(ToolEngine._invoke(tool, {"a": 1}, user_id="u"))
assert invoked[0].type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(invoked[-1], ToolInvokeMeta)
assert invoked[-1].error is None
def test_generic_invoke_success_and_error_paths():
tool = _build_tool()
callback = Mock()
callback.on_tool_execution.side_effect = lambda **kwargs: kwargs["tool_outputs"]
response = list(
ToolEngine.generic_invoke(
tool=tool,
tool_parameters={"x": 1},
user_id="u1",
workflow_tool_callback=callback,
workflow_call_depth=0,
conversation_id="c1",
app_id="a1",
message_id="m1",
)
)
assert response[0].message.text == "ok"
callback.on_tool_start.assert_called_once()
callback.on_tool_execution.assert_called_once()
tool.raise_error = RuntimeError("boom")
error_callback = Mock()
error_callback.on_tool_execution.side_effect = lambda **kwargs: list(kwargs["tool_outputs"])
with pytest.raises(RuntimeError, match="boom"):
list(
ToolEngine.generic_invoke(
tool=tool,
tool_parameters={"x": 1},
user_id="u1",
workflow_tool_callback=error_callback,
workflow_call_depth=0,
)
)
error_callback.on_tool_error.assert_called_once()
def test_agent_invoke_success():
tool = _build_tool(with_llm_parameter=True)
callback = Mock()
message = SimpleNamespace(id="m1", conversation_id="c1")
meta = ToolInvokeMeta.empty()
with patch.object(ToolEngine, "_invoke", return_value=iter([tool.create_text_message("ok"), meta])):
with patch(
"core.tools.tool_engine.ToolFileMessageTransformer.transform_tool_invoke_messages",
side_effect=lambda messages, **kwargs: messages,
):
with patch.object(ToolEngine, "_extract_tool_response_binary_and_text", return_value=iter([])):
with patch.object(ToolEngine, "_create_message_files", return_value=[]):
result_text, message_files, result_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters="hello",
user_id="u1",
tenant_id="tenant-1",
message=message,
invoke_from=InvokeFrom.DEBUGGER,
agent_tool_callback=callback,
)
assert result_text == "ok"
assert message_files == []
assert result_meta.error is None
callback.on_tool_start.assert_called_once()
callback.on_tool_end.assert_called_once()
def test_agent_invoke_param_validation_error():
tool = _build_tool(with_llm_parameter=True)
callback = Mock()
message = SimpleNamespace(id="m1", conversation_id="c1")
with patch.object(ToolEngine, "_invoke", side_effect=ToolParameterValidationError("bad-param")):
error_text, files, error_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters={"a": 1},
user_id="u1",
tenant_id="tenant-1",
message=message,
invoke_from=InvokeFrom.DEBUGGER,
agent_tool_callback=callback,
)
assert "tool parameters validation error" in error_text
assert files == []
assert error_meta.error
def test_agent_invoke_engine_meta_error():
tool = _build_tool(with_llm_parameter=True)
callback = Mock()
message = SimpleNamespace(id="m1", conversation_id="c1")
engine_error = ToolEngineInvokeError(ToolInvokeMeta.error_instance("meta failure"))
with patch.object(ToolEngine, "_invoke", side_effect=engine_error):
error_text, files, error_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters={"a": 1},
user_id="u1",
tenant_id="tenant-1",
message=message,
invoke_from=InvokeFrom.DEBUGGER,
agent_tool_callback=callback,
)
assert "meta failure" in error_text
assert files == []
assert error_meta.error == "meta failure"
def test_agent_invoke_tool_invoke_error():
tool = _build_tool(with_llm_parameter=True)
callback = Mock()
message = SimpleNamespace(id="m1", conversation_id="c1")
with patch.object(ToolEngine, "_invoke", side_effect=ToolInvokeError("invoke boom")):
error_text, files, _ = ToolEngine.agent_invoke(
tool=tool,
tool_parameters={"a": 1},
user_id="u1",
tenant_id="tenant-1",
message=message,
invoke_from=InvokeFrom.DEBUGGER,
agent_tool_callback=callback,
)
assert "tool invoke error" in error_text
assert files == []

View File

@ -0,0 +1,249 @@
"""Unit tests for `ToolFileManager` behavior.
Covers signing/verification, file persistence flows, and retrieval APIs with
mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to
avoid real IO.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import httpx
import pytest
from core.tools.tool_file_manager import ToolFileManager
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16)
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100)
url = ToolFileManager.sign_file("tf-1", ".png")
return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&"))
def _patch_session_factory(session: Mock):
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = False
return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm)
def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
url = ToolFileManager.sign_file("tf-1", ".png")
assert "/files/tools/tf-1.png" in url
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True
def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False
def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0)
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100)
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False
def test_create_file_by_raw_stores_file_and_persists_record() -> None:
manager = ToolFileManager()
session = Mock()
session.refresh.side_effect = lambda model: setattr(model, "id", "tf-1")
def tool_file_factory(**kwargs):
return SimpleNamespace(**kwargs)
with (
patch("core.tools.tool_file_manager.storage") as storage,
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
patch("core.tools.tool_file_manager.guess_extension", return_value=".txt"),
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="abc")),
_patch_session_factory(session),
):
file_model = manager.create_file_by_raw(
user_id="u1",
tenant_id="t1",
conversation_id="c1",
file_binary=b"hello",
mimetype="text/plain",
filename="readme",
)
assert file_model.name.endswith(".txt")
storage.save.assert_called_once()
session.add.assert_called_once()
session.commit.assert_called_once()
session.refresh.assert_called_once_with(file_model)
def test_create_file_by_url_downloads_and_persists_record() -> None:
manager = ToolFileManager()
response = Mock()
response.content = b"binary"
response.headers = {"Content-Type": "application/octet-stream"}
response.raise_for_status.return_value = None
session = Mock()
def tool_file_factory(**kwargs):
return SimpleNamespace(**kwargs)
session.refresh.side_effect = lambda model: setattr(model, "id", "tf-2")
with (
patch("core.tools.tool_file_manager.storage") as storage,
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")),
_patch_session_factory(session),
patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response),
):
file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
assert file_model.file_key.startswith("tools/t1/")
storage.save.assert_called_once()
session.add.assert_called_once_with(file_model)
session.commit.assert_called_once()
session.refresh.assert_called_once_with(file_model)
def test_create_file_by_url_raises_on_timeout() -> None:
manager = ToolFileManager()
with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")):
with pytest.raises(ValueError, match="timeout when downloading file"):
manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
def test_get_file_binary_returns_none_when_not_found() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
# Act
with _patch_session_factory(session):
result = manager.get_file_binary("missing")
# Assert
assert result is None
def test_get_file_binary_returns_bytes_when_found() -> None:
# Arrange
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
storage.load_once.return_value = b"hello"
with _patch_session_factory(session):
result = manager.get_file_binary("id1")
# Assert
assert result == (b"hello", "text/plain")
def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = None
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
# Act
with _patch_session_factory(session):
result = manager.get_file_binary_by_message_file_id("mf-1")
# Assert
assert result is None
def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
# Arrange
manager = ToolFileManager()
message_file = SimpleNamespace(url=None)
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
# Act
with _patch_session_factory(session):
result = manager.get_file_binary_by_message_file_id("mf-1")
# Assert
assert result is None
def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
# Arrange
manager = ToolFileManager()
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = tool_file
session.query.side_effect = [first_query, second_query]
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
storage.load_once.return_value = b"img"
with _patch_session_factory(session):
result = manager.get_file_binary_by_message_file_id("mf-1")
# Assert
assert result == (b"img", "image/png")
def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
# Act
with _patch_session_factory(session):
stream, tool_file = manager.get_file_generator_by_tool_file_id("tool123")
# Assert
assert stream is None
assert tool_file is None
def test_get_file_generator_returns_stream_when_found() -> None:
# Arrange
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
stream = iter([b"a", b"b"])
storage.load_stream.return_value = stream
with (
_patch_session_factory(session),
patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"),
):
result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123")
assert list(result_stream) == [b"a", b"b"]
assert result_file == "validated-file"

View File

@ -0,0 +1,92 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import PropertyMock, patch
import pytest
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
return None
def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController:
controller = object.__new__(ApiToolProviderController)
controller.provider_id = provider_id
return controller
def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderController:
controller = object.__new__(WorkflowToolProviderController)
controller.provider_id = provider_id
return controller
def test_tool_label_manager_filter_tool_labels():
filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
assert set(filtered) == {"search", "news"}
assert len(filtered) == 2
def test_tool_label_manager_update_tool_labels_db():
controller = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
delete_query = mock_db.session.query.return_value.where.return_value
delete_query.delete.return_value = None
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
delete_query.delete.assert_called_once()
# only one valid unique label should be inserted.
assert mock_db.session.add.call_count == 1
mock_db.session.commit.assert_called_once()
def test_tool_label_manager_update_tool_labels_unsupported():
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type]
def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
with patch.object(
_ConcreteBuiltinToolProviderController,
"tool_labels",
new_callable=PropertyMock,
return_value=["search", "news"],
):
builtin = object.__new__(_ConcreteBuiltinToolProviderController)
assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
api = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
labels = ToolLabelManager.get_tool_labels(api)
assert labels == ["search", "news"]
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type]
def test_tool_label_manager_get_tools_labels_batch():
assert ToolLabelManager.get_tools_labels([]) == {}
api = _api_controller("api-1")
wf = _workflow_controller("wf-1")
records = [
SimpleNamespace(tool_id="api-1", label_name="search"),
SimpleNamespace(tool_id="api-1", label_name="news"),
SimpleNamespace(tool_id="wf-1", label_name="utilities"),
]
with patch("core.tools.tool_label_manager.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = records
labels = ToolLabelManager.get_tools_labels([api, wf])
assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
with pytest.raises(ValueError, match="Unsupported tool type"):
ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item]

View File

@ -0,0 +1,899 @@
from __future__ import annotations
"""Unit tests for ToolManager behavior with mocked providers and collaborators."""
import json
import threading
from types import SimpleNamespace
from typing import Any
from unittest.mock import Mock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_manager import ToolManager
class _SimpleContextVar:
def __init__(self):
self._is_set = False
self._value: Any = None
def get(self):
if not self._is_set:
raise LookupError
return self._value
def set(self, value: Any):
self._value = value
self._is_set = True
def _cm(session: Any):
context = Mock()
context.__enter__ = Mock(return_value=session)
context.__exit__ = Mock(return_value=False)
return context
def _setup_list_providers_from_api_mocks(
monkeypatch,
*,
session: Mock,
hardcoded_controller: SimpleNamespace,
plugin_controller: PluginToolProviderController,
api_controller: SimpleNamespace,
workflow_controller: SimpleNamespace,
):
mock_db = Mock()
mock_db.engine = object()
monkeypatch.setattr("core.tools.tool_manager.db", mock_db)
monkeypatch.setattr("core.tools.tool_manager.Session", lambda *args, **kwargs: _cm(session))
monkeypatch.setattr(
ToolManager,
"list_builtin_providers",
Mock(return_value=[hardcoded_controller, plugin_controller]),
)
monkeypatch.setattr(
ToolManager,
"list_default_builtin_providers",
Mock(return_value=[SimpleNamespace(provider="hardcoded")]),
)
monkeypatch.setattr("core.tools.tool_manager.is_filtered", lambda *args, **kwargs: False)
monkeypatch.setattr(
"core.tools.tool_manager.ToolTransformService.builtin_provider_to_user_provider",
lambda **kwargs: SimpleNamespace(name=kwargs["provider_controller"].entity.identity.name),
)
monkeypatch.setattr(
"core.tools.tool_manager.ToolTransformService.api_provider_to_controller",
Mock(side_effect=[api_controller, RuntimeError("invalid")]),
)
monkeypatch.setattr(
"core.tools.tool_manager.ToolTransformService.api_provider_to_user_provider",
Mock(return_value=SimpleNamespace(name="api-provider")),
)
monkeypatch.setattr(
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller",
Mock(side_effect=[workflow_controller, RuntimeError("deleted app")]),
)
monkeypatch.setattr(
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_user_provider",
Mock(return_value=SimpleNamespace(name="workflow-provider")),
)
monkeypatch.setattr(
"core.tools.tool_manager.ToolLabelManager.get_tools_labels",
Mock(side_effect=[{"api-1": ["search"]}, {"wf-1": ["utility"]}]),
)
mock_mcp_service = Mock()
mock_mcp_service.list_providers.return_value = [SimpleNamespace(name="mcp-provider")]
monkeypatch.setattr("core.tools.tool_manager.MCPToolManageService", Mock(return_value=mock_mcp_service))
monkeypatch.setattr("core.tools.tool_manager.BuiltinToolProviderSort.sort", lambda providers: providers)
@pytest.fixture(autouse=True)
def _reset_tool_manager_state():
old_hardcoded = ToolManager._hardcoded_providers.copy()
old_loaded = ToolManager._builtin_providers_loaded
old_labels = ToolManager._builtin_tools_labels.copy()
try:
yield
finally:
ToolManager._hardcoded_providers = old_hardcoded
ToolManager._builtin_providers_loaded = old_loaded
ToolManager._builtin_tools_labels = old_labels
def test_get_hardcoded_provider_loads_cache_when_empty():
provider = Mock()
ToolManager._hardcoded_providers = {}
def _load():
ToolManager._hardcoded_providers["weather"] = provider
with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load) as mock_load:
assert ToolManager.get_hardcoded_provider("weather") is provider
mock_load.assert_called_once()
def test_get_builtin_provider_returns_plugin_for_missing_hardcoded():
hardcoded = Mock()
plugin_provider = Mock()
ToolManager._hardcoded_providers = {"time": hardcoded}
with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider):
assert ToolManager.get_builtin_provider("time", "tenant-1") is hardcoded
assert ToolManager.get_builtin_provider("plugin/time", "tenant-1") is plugin_provider
def test_get_plugin_provider_uses_context_cache():
provider_context = _SimpleContextVar()
lock_context = _SimpleContextVar()
lock_context.set(threading.Lock())
provider_entity = SimpleNamespace(declaration=Mock(), plugin_id="pid", plugin_unique_identifier="uid")
with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context):
with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context):
with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls:
mock_manager_cls.return_value.fetch_tool_provider.return_value = provider_entity
controller = SimpleNamespace(name="controller")
with patch("core.tools.tool_manager.PluginToolProviderController", return_value=controller):
built = ToolManager.get_plugin_provider("provider-a", "tenant-1")
cached = ToolManager.get_plugin_provider("provider-a", "tenant-1")
assert built is controller
assert cached is controller
mock_manager_cls.return_value.fetch_tool_provider.assert_called_once()
def test_get_plugin_provider_raises_when_provider_missing():
provider_context = _SimpleContextVar()
lock_context = _SimpleContextVar()
lock_context.set(threading.Lock())
with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context):
with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context):
with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls:
mock_manager_cls.return_value.fetch_tool_provider.return_value = None
with pytest.raises(ToolProviderNotFoundError, match="plugin provider provider-a not found"):
ToolManager.get_plugin_provider("provider-a", "tenant-1")
def test_get_tool_runtime_builtin_without_credentials():
tool = Mock()
tool.fork_tool_runtime.return_value = "runtime-tool"
controller = SimpleNamespace(get_tool=Mock(return_value=tool), need_credentials=False)
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
result = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="time",
tool_name="current_time",
tenant_id="tenant-1",
)
assert result == "runtime-tool"
runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"]
assert runtime.tenant_id == "tenant-1"
assert runtime.credentials == {}
def test_get_tool_runtime_builtin_missing_tool_raises():
controller = SimpleNamespace(get_tool=Mock(return_value=None), need_credentials=False)
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
with pytest.raises(ToolProviderNotFoundError, match="builtin tool missing not found"):
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="time",
tool_name="missing",
tenant_id="tenant-1",
)
def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
tool = Mock()
tool.fork_tool_runtime.return_value = "runtime-tool"
controller = SimpleNamespace(
get_tool=Mock(return_value=tool),
need_credentials=True,
get_credentials_schema_by_type=Mock(return_value=[]),
)
builtin_provider = SimpleNamespace(
id="cred-1",
credential_type=CredentialType.API_KEY.value,
credentials={"encrypted": "value"},
expires_at=-1,
user_id="user-1",
)
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
builtin_provider
)
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key": "secret"}
cache = Mock()
with patch("core.tools.tool_manager.create_provider_encrypter", return_value=(encrypter, cache)):
result = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="time",
tool_name="weekday",
tenant_id="tenant-1",
)
assert result == "runtime-tool"
runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"]
assert runtime.credentials == {"api_key": "secret"}
assert runtime.credential_type == CredentialType.API_KEY
@patch("core.tools.tool_manager.create_provider_encrypter")
@patch("core.plugin.impl.oauth.OAuthHandler")
@patch(
"services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client",
return_value={"client_id": "id"},
)
@patch("core.tools.tool_manager.db")
@patch("core.tools.tool_manager.time.time", return_value=1000)
@patch("core.helper.credential_utils.check_credential_policy_compliance")
def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
mock_check,
mock_time,
mock_db,
mock_get_oauth_client,
mock_oauth_handler_cls,
mock_create_provider_encrypter,
):
tool = Mock()
tool.fork_tool_runtime.return_value = "runtime-tool"
controller = SimpleNamespace(
get_tool=Mock(return_value=tool),
need_credentials=True,
get_credentials_schema_by_type=Mock(return_value=[]),
)
builtin_provider = SimpleNamespace(
id="cred-1",
credential_type=CredentialType.OAUTH2.value,
credentials={"encrypted": "value"},
encrypted_credentials=None,
expires_at=1,
user_id="user-1",
)
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
encrypter = Mock()
encrypter.decrypt.return_value = {"token": "old"}
encrypter.encrypt.return_value = {"token": "encrypted"}
cache = Mock()
mock_create_provider_encrypter.return_value = (encrypter, cache)
mock_oauth_handler_cls.return_value.refresh_credentials.return_value = refreshed
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
result = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="time",
tool_name="weekday",
tenant_id="tenant-1",
)
assert result == "runtime-tool"
assert builtin_provider.expires_at == refreshed.expires_at
assert builtin_provider.encrypted_credentials == json.dumps({"token": "encrypted"})
mock_db.session.commit.assert_called_once()
cache.delete.assert_called_once()
def test_get_tool_runtime_builtin_plugin_provider_deleted_raises():
plugin_controller = object.__new__(PluginToolProviderController)
plugin_controller.entity = SimpleNamespace(credentials_schema=[{"name": "k"}], oauth_schema=None)
plugin_controller.get_tool = Mock(return_value=Mock())
plugin_controller.get_credentials_schema_by_type = Mock(return_value=[])
with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_controller):
with patch("core.tools.tool_manager.is_valid_uuid", return_value=True):
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = None
with pytest.raises(ToolProviderNotFoundError, match="provider has been deleted"):
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="time",
tool_name="weekday",
tenant_id="tenant-1",
credential_id="uuid-id",
)
def test_get_tool_runtime_api_path():
api_tool = Mock()
api_tool.fork_tool_runtime.return_value = "api-runtime"
api_provider = Mock()
api_provider.get_tool.return_value = api_tool
with patch.object(ToolManager, "get_api_provider_controller", return_value=(api_provider, {"c": "enc"})):
encrypter = Mock()
encrypter.decrypt.return_value = {"c": "dec"}
with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())):
assert (
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
)
== "api-runtime"
)
def test_get_tool_runtime_workflow_path():
workflow_provider = SimpleNamespace(tenant_id="tenant-1")
workflow_tool = Mock()
workflow_tool.fork_tool_runtime.return_value = "wf-runtime"
workflow_controller = Mock()
workflow_controller.get_tools.return_value = [workflow_tool]
session = Mock()
session.begin.return_value = _cm(None)
session.scalar.return_value = workflow_provider
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
with patch(
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller",
return_value=workflow_controller,
):
assert (
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.WORKFLOW,
provider_id="wf-1",
tool_name="wf",
tenant_id="tenant-1",
)
== "wf-runtime"
)
def test_get_tool_runtime_plugin_path():
with patch.object(
ToolManager,
"get_plugin_provider",
return_value=SimpleNamespace(get_tool=lambda _: "plugin-tool"),
):
assert (
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.PLUGIN,
provider_id="plugin-1",
tool_name="p",
tenant_id="tenant-1",
)
== "plugin-tool"
)
def test_get_tool_runtime_mcp_path():
with patch.object(
ToolManager,
"get_mcp_provider_controller",
return_value=SimpleNamespace(get_tool=lambda _: "mcp-tool"),
):
assert (
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.MCP,
provider_id="mcp-1",
tool_name="m",
tenant_id="tenant-1",
)
== "mcp-tool"
)
def test_get_tool_runtime_app_not_implemented():
with pytest.raises(NotImplementedError, match="app provider not implemented"):
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.APP,
provider_id="app",
tool_name="x",
tenant_id="tenant-1",
)
def test_get_agent_runtime_apply_runtime_parameters():
parameter = ToolParameter.get_simple_instance(
name="query",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
parameter.form = ToolParameter.ToolParameterForm.FORM
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager):
agent_tool = SimpleNamespace(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tool_parameters={"query": "hello"},
credential_id=None,
)
result = ToolManager.get_agent_tool_runtime(
tenant_id="tenant-1",
app_id="app-1",
agent_tool=agent_tool,
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert result is tool_runtime
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
def test_get_workflow_runtime_apply_runtime_parameters():
parameter = ToolParameter.get_simple_instance(
name="query",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
parameter.form = ToolParameter.ToolParameterForm.FORM
workflow_tool = SimpleNamespace(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tool_configurations={"query": "hello"},
credential_id=None,
)
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager):
workflow_result = ToolManager.get_workflow_tool_runtime(
tenant_id="tenant-1",
app_id="app-1",
node_id="node-1",
workflow_tool=workflow_tool,
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert workflow_result is tool_runtime2
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
def test_get_agent_runtime_raises_when_runtime_missing():
tool_runtime = SimpleNamespace(runtime=None, get_merged_runtime_parameters=lambda: [])
agent_tool = SimpleNamespace(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tool_parameters={},
credential_id=None,
)
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={}):
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=Mock()):
with pytest.raises(ValueError, match="runtime not found"):
ToolManager.get_agent_tool_runtime(
tenant_id="tenant-1",
app_id="app-1",
agent_tool=agent_tool,
)
def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
form_param = ToolParameter.get_simple_instance(
name="q",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
form_param.form = ToolParameter.ToolParameterForm.FORM
llm_param = ToolParameter.get_simple_instance(
name="llm",
llm_description="llm",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
llm_param.form = ToolParameter.ToolParameterForm.LLM
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
result = ToolManager.get_tool_runtime_from_plugin(
tool_type=ToolProviderType.API,
tenant_id="tenant-1",
provider="api-1",
tool_name="search",
tool_parameters={"q": "hello", "llm": "ignore"},
)
assert result is tool_entity
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
def test_hardcoded_provider_icon_success():
provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg")))
with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider):
with patch("core.tools.tool_manager.path.exists", return_value=True):
with patch("core.tools.tool_manager.mimetypes.guess_type", return_value=("image/svg+xml", None)):
icon_path, mime = ToolManager.get_hardcoded_provider_icon("time")
assert icon_path.endswith("icon.svg")
assert mime == "image/svg+xml"
def test_hardcoded_provider_icon_missing_raises():
provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg")))
with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider):
with patch("core.tools.tool_manager.path.exists", return_value=False):
with pytest.raises(ToolProviderNotFoundError, match="icon not found"):
ToolManager.get_hardcoded_provider_icon("time")
def test_list_hardcoded_providers_cache_hit():
ToolManager._hardcoded_providers = {"p": Mock()}
ToolManager._builtin_providers_loaded = True
assert list(ToolManager.list_hardcoded_providers()) == list(ToolManager._hardcoded_providers.values())
def test_clear_hardcoded_providers_cache_resets():
ToolManager._hardcoded_providers = {"p": Mock()}
ToolManager._builtin_providers_loaded = True
ToolManager.clear_hardcoded_providers_cache()
assert ToolManager._hardcoded_providers == {}
assert ToolManager._builtin_providers_loaded is False
def test_list_hardcoded_providers_internal_loader():
good_provider = SimpleNamespace(
entity=SimpleNamespace(identity=SimpleNamespace(name="good")),
get_tools=lambda: [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="tool-a", label="A")))],
)
provider_class = Mock(return_value=good_provider)
with patch("core.tools.tool_manager.listdir", return_value=["good", "bad", "__skip"]):
with patch("core.tools.tool_manager.path.isdir", side_effect=lambda p: "good" in p or "bad" in p):
with patch(
"core.tools.tool_manager.load_single_subclass_from_source",
side_effect=[provider_class, RuntimeError("boom")],
):
ToolManager._hardcoded_providers = {}
ToolManager._builtin_tools_labels = {}
providers = list(ToolManager._list_hardcoded_providers())
assert providers == [good_provider]
assert ToolManager._hardcoded_providers["good"] is good_provider
assert ToolManager._builtin_tools_labels["tool-a"] == "A"
assert ToolManager._builtin_providers_loaded is True
def test_get_tool_label_loads_cache_and_handles_missing():
ToolManager._builtin_tools_labels = {}
def _load():
ToolManager._builtin_tools_labels["tool-a"] = "Label A"
with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load):
assert ToolManager.get_tool_label("tool-a") == "Label A"
assert ToolManager.get_tool_label("missing") is None
def test_list_default_builtin_providers_for_postgres_and_mysql():
provider_records = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
for scheme in ("postgresql", "mysql"):
session = Mock()
session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
session.query.return_value.where.return_value.all.return_value = provider_records
with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)):
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
providers = ToolManager.list_default_builtin_providers("tenant-1")
assert providers == provider_records
def test_list_providers_from_api_covers_builtin_api_workflow_and_mcp(monkeypatch):
hardcoded_controller = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="hardcoded")))
plugin_controller = object.__new__(PluginToolProviderController)
plugin_controller.entity = SimpleNamespace(identity=SimpleNamespace(name="plugin-provider"))
api_db_provider_good = SimpleNamespace(id="api-1")
api_db_provider_bad = SimpleNamespace(id="api-2")
api_controller = SimpleNamespace(provider_id="api-1")
workflow_db_provider_good = SimpleNamespace(id="wf-1")
workflow_db_provider_bad = SimpleNamespace(id="wf-2")
workflow_controller = SimpleNamespace(provider_id="wf-1")
session = Mock()
session.scalars.side_effect = [
SimpleNamespace(all=lambda: [api_db_provider_good, api_db_provider_bad]),
SimpleNamespace(all=lambda: [workflow_db_provider_good, workflow_db_provider_bad]),
]
_setup_list_providers_from_api_mocks(
monkeypatch,
session=session,
hardcoded_controller=hardcoded_controller,
plugin_controller=plugin_controller,
api_controller=api_controller,
workflow_controller=workflow_controller,
)
providers = ToolManager.list_providers_from_api(user_id="user-1", tenant_id="tenant-1", typ="")
names = {provider.name for provider in providers}
assert {"hardcoded", "plugin-provider", "api-provider", "workflow-provider", "mcp-provider"} <= names
def test_get_api_provider_controller_returns_controller_and_credentials():
provider = SimpleNamespace(
id="api-1",
tenant_id="tenant-1",
name="api-provider",
description="desc",
credentials={"auth_type": "api_key_query"},
credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}',
schema_type="openapi",
schema="schema",
tools=[],
icon='{"background": "#000", "content": "A"}',
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value = db_query
with patch(
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
) as mock_from_db:
built_controller, credentials = ToolManager.get_api_provider_controller("tenant-1", "api-1")
assert built_controller is controller
assert credentials == provider.credentials
mock_from_db.assert_called_with(provider, ApiProviderAuthType.API_KEY_QUERY)
controller.load_bundled_tools.assert_called_once_with(provider.tools)
def test_user_get_api_provider_masks_credentials_and_adds_labels():
provider = SimpleNamespace(
id="api-1",
tenant_id="tenant-1",
name="api-provider",
description="desc",
credentials={"auth_type": "api_key_query"},
credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}',
schema_type="openapi",
schema="schema",
tools=[],
icon='{"background": "#000", "content": "A"}',
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value = db_query
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key_value": "secret"}
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())):
with patch("core.tools.tool_manager.ToolLabelManager.get_tool_labels", return_value=["search"]):
user_payload = ToolManager.user_get_api_provider("api-provider", "tenant-1")
assert user_payload["credentials"]["api_key_value"] == "***"
assert user_payload["labels"] == ["search"]
def test_get_api_provider_controller_not_found_raises():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
ToolManager.get_api_provider_controller("tenant-1", "missing")
def test_get_mcp_provider_controller_returns_controller():
provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"})
controller = Mock()
session = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
mock_service = mock_service_cls.return_value
mock_service.get_provider.return_value = provider_entity
with patch("core.tools.tool_manager.MCPToolProviderController.from_db", return_value=controller):
built = ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1")
assert built is controller
def test_generate_mcp_tool_icon_url_returns_provider_icon():
provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"})
session = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
mock_service = mock_service_cls.return_value
mock_service.get_provider_entity.return_value = provider_entity
assert ToolManager.generate_mcp_tool_icon_url("tenant-1", "mcp-1") == provider_entity.provider_icon
def test_get_mcp_provider_controller_missing_raises():
session = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
mock_service_cls.return_value.get_provider.side_effect = ValueError("missing")
with pytest.raises(ToolProviderNotFoundError, match="mcp provider mcp-1 not found"):
ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1")
def test_generate_tool_icon_urls_for_builtin_and_plugin():
with patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "https://console.example.com"):
builtin_url = ToolManager.generate_builtin_tool_icon_url("time")
plugin_url = ToolManager.generate_plugin_tool_icon_url("tenant-1", "icon.svg")
assert builtin_url.endswith("/tool-provider/builtin/time/icon")
assert "/plugin/icon" in plugin_url
def test_generate_tool_icon_urls_for_workflow_and_api():
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = None
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
def test_get_tool_icon_for_builtin_provider_variants():
plugin_provider = object.__new__(PluginToolProviderController)
plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg"))
with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_provider):
with patch.object(ToolManager, "generate_plugin_tool_icon_url", return_value="plugin-icon"):
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "plugin-provider") == "plugin-icon"
with patch.object(ToolManager, "get_builtin_provider", return_value=SimpleNamespace()):
with patch.object(ToolManager, "generate_builtin_tool_icon_url", return_value="builtin-icon"):
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "time") == "builtin-icon"
def test_get_tool_icon_for_api_workflow_and_mcp():
with patch.object(ToolManager, "generate_api_tool_icon_url", return_value={"background": "#000"}):
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.API, "api-1") == {"background": "#000"}
with patch.object(ToolManager, "generate_workflow_tool_icon_url", return_value={"background": "#111"}):
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.WORKFLOW, "wf-1") == {"background": "#111"}
with patch.object(ToolManager, "generate_mcp_tool_icon_url", return_value={"background": "#222"}):
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.MCP, "mcp-1") == {"background": "#222"}
def test_get_tool_icon_plugin_error_returns_default():
plugin_provider = object.__new__(PluginToolProviderController)
plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg"))
with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider):
with patch.object(ToolManager, "generate_plugin_tool_icon_url", side_effect=RuntimeError("fail")):
icon = ToolManager.get_tool_icon("tenant-1", ToolProviderType.PLUGIN, "plugin-provider")
assert icon["background"] == "#252525"
def test_get_tool_icon_invalid_provider_type_raises():
with pytest.raises(ValueError, match="provider type"):
ToolManager.get_tool_icon("tenant-1", "invalid", "x") # type: ignore[arg-type]
def test_convert_tool_parameters_type_agent_and_workflow_branches():
file_param = ToolParameter.get_simple_instance(
name="file",
llm_description="file",
typ=ToolParameter.ToolParameterType.FILE,
required=True,
)
file_param.form = ToolParameter.ToolParameterForm.FORM
with pytest.raises(ValueError, match="file type parameter file not supported in agent"):
ToolManager._convert_tool_parameters_type(
parameters=[file_param],
variable_pool=None,
tool_configurations={"file": "x"},
typ="agent",
)
text_param = ToolParameter.get_simple_instance(
name="text",
llm_description="text",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
text_param.form = ToolParameter.ToolParameterForm.FORM
plain = ToolManager._convert_tool_parameters_type(
parameters=[text_param],
variable_pool=None,
tool_configurations={"text": "hello"},
typ="workflow",
)
assert plain == {"text": "hello"}
variable_pool = Mock()
variable_pool.get.return_value = SimpleNamespace(value="from-variable")
variable_pool.convert_template.return_value = SimpleNamespace(text="from-template")
mixed = ToolManager._convert_tool_parameters_type(
parameters=[text_param],
variable_pool=variable_pool,
tool_configurations={"text": {"type": "mixed", "value": "Hello {{name}}"}},
typ="workflow",
)
assert mixed == {"text": "from-template"}
variable = ToolManager._convert_tool_parameters_type(
parameters=[text_param],
variable_pool=variable_pool,
tool_configurations={"text": {"type": "variable", "value": ["sys", "query"]}},
typ="workflow",
)
assert variable == {"text": "from-variable"}
def test_convert_tool_parameters_type_constant_branch():
text_param = ToolParameter.get_simple_instance(
name="text",
llm_description="text",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
text_param.form = ToolParameter.ToolParameterForm.FORM
variable_pool = Mock()
constant = ToolManager._convert_tool_parameters_type(
parameters=[text_param],
variable_pool=variable_pool,
tool_configurations={"text": {"type": "constant", "value": "fixed"}},
typ="workflow",
)
assert constant == {"text": "fixed"}

View File

@ -0,0 +1,110 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
import pytest
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.errors import ToolProviderCredentialValidationError
class _DummyTool(Tool):
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("ok")
class _DummyController(ToolProviderController):
def get_tool(self, tool_name: str) -> Tool:
entity = ToolEntity(
identity=ToolIdentity(
author="author",
name=tool_name,
label=I18nObject(en_US=tool_name),
provider="provider",
),
parameters=[],
)
return _DummyTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant"))
def _provider_identity() -> ToolProviderIdentity:
return ToolProviderIdentity(
author="author",
name="provider",
description=I18nObject(en_US="desc"),
icon="icon.svg",
label=I18nObject(en_US="Provider"),
)
def test_tool_provider_controller_get_credentials_schema_returns_deep_copy():
entity = ToolProviderEntity(
identity=_provider_identity(),
credentials_schema=[ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="api_key", required=False)],
)
controller = _DummyController(entity=entity)
schema = controller.get_credentials_schema()
schema[0].name = "changed"
assert controller.entity.credentials_schema[0].name == "api_key"
def test_tool_provider_controller_default_provider_type():
entity = ToolProviderEntity(identity=_provider_identity(), credentials_schema=[])
controller = _DummyController(entity=entity)
assert controller.provider_type == ToolProviderType.BUILT_IN
def test_validate_credentials_format_covers_required_default_and_type_rules():
select_options = [ProviderConfig.Option(value="opt-a", label=I18nObject(en_US="A"))]
entity = ToolProviderEntity(
identity=_provider_identity(),
credentials_schema=[
ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="required_text", required=True),
ProviderConfig(type=ProviderConfig.Type.SECRET_INPUT, name="secret", required=False),
ProviderConfig(type=ProviderConfig.Type.SELECT, name="choice", required=False, options=select_options),
ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="with_default", required=False, default="x"),
],
)
controller = _DummyController(entity=entity)
credentials = {"required_text": "value", "secret": None, "choice": "opt-a"}
controller.validate_credentials_format(credentials)
assert credentials["with_default"] == "x"
with pytest.raises(ToolProviderCredentialValidationError, match="not found"):
controller.validate_credentials_format({"required_text": "value", "unknown": "v"})
with pytest.raises(ToolProviderCredentialValidationError, match="is required"):
controller.validate_credentials_format({"secret": "s"})
with pytest.raises(ToolProviderCredentialValidationError, match="should be string"):
controller.validate_credentials_format({"required_text": 123}) # type: ignore[arg-type]
with pytest.raises(ToolProviderCredentialValidationError, match="should be one of"):
controller.validate_credentials_format({"required_text": "value", "choice": "opt-b"})

View File

@ -0,0 +1,148 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.utils.configuration import ToolParameterConfigurationManager
class _DummyTool(Tool):
runtime_overrides: list[ToolParameter]
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, runtime_overrides: list[ToolParameter]):
super().__init__(entity=entity, runtime=runtime)
self.runtime_overrides = runtime_overrides
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("ok")
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
return self.runtime_overrides
def _param(
name: str,
*,
typ: ToolParameter.ToolParameterType,
form: ToolParameter.ToolParameterForm,
required: bool = False,
) -> ToolParameter:
return ToolParameter(
name=name,
label=I18nObject(en_US=name),
placeholder=I18nObject(en_US=""),
human_description=I18nObject(en_US=""),
type=typ,
form=form,
required=required,
default=None,
)
def _build_manager() -> ToolParameterConfigurationManager:
base_params = [
_param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM),
_param("plain", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM),
]
runtime_overrides = [
_param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM),
_param("runtime_only", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM),
]
entity = ToolEntity(
identity=ToolIdentity(author="a", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
parameters=base_params,
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
tool = _DummyTool(entity=entity, runtime=runtime, runtime_overrides=runtime_overrides)
return ToolParameterConfigurationManager(
tenant_id="tenant-1",
tool_runtime=tool,
provider_name="provider-a",
provider_type=ToolProviderType.BUILT_IN,
identity_id="ID.1",
)
def test_merge_and_mask_parameters():
manager = _build_manager()
masked = manager.mask_tool_parameters({"secret": "abcdefghi", "plain": "x", "runtime_only": "y"})
assert masked["secret"] == "ab*****hi"
assert masked["plain"] == "x"
assert masked["runtime_only"] == "y"
def test_encrypt_tool_parameters():
manager = _build_manager()
with patch("core.tools.utils.configuration.encrypter.encrypt_token", return_value="enc"):
encrypted = manager.encrypt_tool_parameters({"secret": "raw", "plain": "x"})
assert encrypted["secret"] == "enc"
assert encrypted["plain"] == "x"
def test_decrypt_tool_parameters_cache_hit_and_miss():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = {"secret": "cached"}
assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"}
cache.set.assert_not_called()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = None
with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"):
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
assert decrypted["secret"] == "dec"
cache.set.assert_called_once()
def test_delete_tool_parameters_cache():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
manager.delete_tool_parameters_cache()
cache_cls.return_value.delete.assert_called_once()
def test_configuration_manager_decrypt_suppresses_errors():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = None
with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")):
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
# decryption failure is suppressed, original value is retained.
assert decrypted["secret"] == "enc"

View File

@ -1,10 +1,13 @@
import copy
from unittest.mock import patch
from types import SimpleNamespace
from typing import Any
from unittest.mock import Mock, patch
import pytest
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_encryption import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
# ---------------------------
@ -13,13 +16,13 @@ from core.helper.provider_encryption import ProviderConfigEncrypter
class NoopCache:
"""Simple cache stub: always returns None, does nothing for set/delete."""
def get(self):
def get(self) -> Any | None:
return None
def set(self, config):
def set(self, config: Any) -> None:
pass
def delete(self):
def delete(self) -> None:
pass
@ -179,3 +182,35 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
out = encrypter_obj.decrypt({"password": "ENC_ERR"})
assert out["password"] == "ENC_ERR"
def test_create_tool_provider_encrypter_builds_cache_and_encrypter():
basic_config = BasicProviderConfig(name="key", type=BasicProviderConfig.Type.TEXT_INPUT)
credential_schema_item = SimpleNamespace(to_basic_provider_config=lambda: basic_config)
controller = SimpleNamespace(
provider_type=SimpleNamespace(value="builtin"),
entity=SimpleNamespace(identity=SimpleNamespace(name="provider-a")),
get_credentials_schema=lambda: [credential_schema_item],
)
cache_instance = Mock()
encrypter_instance = Mock()
with patch(
"core.tools.utils.encryption.SingletonProviderCredentialsCache", return_value=cache_instance
) as cache_cls:
with patch("core.tools.utils.encryption.ProviderConfigEncrypter", return_value=encrypter_instance) as enc_cls:
encrypter, cache = create_tool_provider_encrypter("tenant-1", controller)
assert encrypter is encrypter_instance
assert cache is cache_instance
cache_cls.assert_called_once_with(
tenant_id="tenant-1",
provider_type="builtin",
provider_identity="provider-a",
)
enc_cls.assert_called_once_with(
tenant_id="tenant-1",
config=[basic_config],
provider_config_cache=cache_instance,
)

View File

@ -0,0 +1,478 @@
from __future__ import annotations
import uuid
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from yaml import YAMLError
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.models.document import Document as RagDocument
from core.tools.utils.dataset_retriever import dataset_multi_retriever_tool as multi_retriever_module
from core.tools.utils.dataset_retriever import dataset_retriever_tool as single_retriever_module
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool as SingleDatasetRetrieverTool
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.utils.yaml_utils import _load_yaml_file, load_yaml_file_cached
def _retrieve_config() -> DatasetRetrieveConfigEntity:
return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE)
class _FakeFlaskApp:
def app_context(self):
return nullcontext()
class _ImmediateThread:
def __init__(self, target=None, kwargs=None, **_kwargs):
self._target = target
self._kwargs = kwargs or {}
def start(self):
if self._target is not None:
self._target(**self._kwargs)
def join(self):
return None
class _TestHitCallback(DatasetIndexToolCallbackHandler):
def __init__(self):
self.queries: list[tuple[str, str]] = []
self.documents: list[RagDocument] | None = None
self.resources = None
def on_query(self, query: str, dataset_id: str):
self.queries.append((query, dataset_id))
def on_tool_end(self, documents: list[RagDocument]):
self.documents = documents
def return_retriever_resource_info(self, resource):
self.resources = list(resource)
def test_remove_leading_symbols_preserves_markdown_link_and_strips_punctuation():
markdown = "[Example](https://example.com) content"
assert remove_leading_symbols(markdown) == markdown
assert remove_leading_symbols("...Hello world") == "Hello world"
def test_is_valid_uuid_handles_valid_invalid_and_empty_values():
assert is_valid_uuid(str(uuid.uuid4())) is True
assert is_valid_uuid("not-a-uuid") is False
assert is_valid_uuid("") is False
assert is_valid_uuid(None) is False
def test_load_yaml_file_valid(tmp_path):
valid_file = tmp_path / "valid.yaml"
valid_file.write_text("a: 1\nb: two\n", encoding="utf-8")
loaded = _load_yaml_file(file_path=str(valid_file))
assert loaded == {"a": 1, "b": "two"}
def test_load_yaml_file_missing(tmp_path):
with pytest.raises(FileNotFoundError):
_load_yaml_file(file_path=str(tmp_path / "missing.yaml"))
def test_load_yaml_file_invalid(tmp_path):
invalid_file = tmp_path / "invalid.yaml"
invalid_file.write_text("a: [1, 2\n", encoding="utf-8")
with pytest.raises(YAMLError):
_load_yaml_file(file_path=str(invalid_file))
def test_load_yaml_file_cached_hits(tmp_path):
valid_file = tmp_path / "valid.yaml"
valid_file.write_text("a: 1\nb: two\n", encoding="utf-8")
load_yaml_file_cached.cache_clear()
assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"}
assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"}
assert load_yaml_file_cached.cache_info().hits == 1
def test_single_dataset_retriever_from_dataset_builds_name_and_description():
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="Knowledge", description=None)
tool = SingleDatasetRetrieverTool.from_dataset(
dataset=dataset,
retrieve_config=_retrieve_config(),
return_resource=False,
retriever_from="prod",
inputs={},
)
assert tool.name == "dataset_dataset_1"
assert tool.description == "useful for when you want to answer queries about the Knowledge"
def test_single_dataset_retriever_external_run_returns_content_and_resources():
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
name="Knowledge Base",
provider="external",
indexing_technique="high_quality",
retrieval_model={},
)
callback = _TestHitCallback()
dataset_retrieval = Mock()
dataset_retrieval.get_metadata_filter_condition.return_value = (
{"dataset-1": ["doc-a"]},
{"logical_operator": "and"},
)
db_session = Mock()
db_session.scalar.return_value = dataset
external_documents = [
{"content": "first", "metadata": {"document_id": "doc-a"}, "score": 0.9, "title": "Doc A"},
{"content": "second", "metadata": {"document_id": "doc-b"}, "score": 0.8, "title": "Doc B"},
]
tool = SingleDatasetRetrieverTool(
tenant_id="tenant-1",
dataset_id="dataset-1",
retrieve_config=_retrieve_config(),
return_resource=True,
retriever_from="dev",
hit_callbacks=[callback],
inputs={"x": 1},
)
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
with patch.object(
single_retriever_module.ExternalDatasetService,
"fetch_external_knowledge_retrieval",
return_value=external_documents,
) as fetch_mock:
result = tool.run(query="hello")
assert result == "first\nsecond"
assert callback.queries == [("hello", "dataset-1")]
assert callback.resources is not None
resource_info = callback.resources
assert [item.position for item in resource_info] == [1, 2]
assert resource_info[0].dataset_id == "dataset-1"
fetch_mock.assert_called_once()
def test_single_dataset_retriever_returns_empty_when_metadata_filter_finds_no_documents():
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
name="Knowledge Base",
provider="internal",
indexing_technique="high_quality",
retrieval_model=None,
)
dataset_retrieval = Mock()
dataset_retrieval.get_metadata_filter_condition.return_value = ({"dataset-1": []}, {"logical_operator": "and"})
db_session = Mock()
db_session.scalar.return_value = dataset
tool = SingleDatasetRetrieverTool(
tenant_id="tenant-1",
dataset_id="dataset-1",
retrieve_config=_retrieve_config(),
return_resource=False,
retriever_from="prod",
hit_callbacks=[_TestHitCallback()],
inputs={},
)
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
with patch.object(single_retriever_module.RetrievalService, "retrieve") as retrieve_mock:
result = tool.run(query="hello")
assert result == ""
retrieve_mock.assert_not_called()
def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
name="Knowledge Base",
provider="internal",
indexing_technique="high_quality",
retrieval_model={
"search_method": "semantic_search",
"score_threshold_enabled": True,
"score_threshold": 0.2,
"reranking_enable": True,
"reranking_model": {"reranking_provider_name": "provider", "reranking_model_name": "model"},
"reranking_mode": "reranking_model",
"weights": {"vector_setting": {"vector_weight": 0.6}},
},
)
callback = _TestHitCallback()
dataset_retrieval = Mock()
dataset_retrieval.get_metadata_filter_condition.return_value = (None, None)
low_segment = SimpleNamespace(
id="seg-low",
dataset_id="dataset-1",
document_id="doc-low",
content="raw low",
answer="low answer",
hit_count=1,
word_count=10,
position=3,
index_node_hash="hash-low",
get_sign_content=lambda: "signed low",
)
high_segment = SimpleNamespace(
id="seg-high",
dataset_id="dataset-1",
document_id="doc-high",
content="raw high",
answer=None,
hit_count=9,
word_count=25,
position=1,
index_node_hash="hash-high",
get_sign_content=lambda: "signed high",
)
records = [
SimpleNamespace(segment=low_segment, score=0.2, summary="summary low"),
SimpleNamespace(segment=high_segment, score=0.9, summary=None),
]
documents = [
RagDocument(page_content="first", metadata={"doc_id": "node-low", "score": 0.2}),
RagDocument(page_content="second", metadata={"doc_id": "node-high", "score": 0.9}),
]
lookup_doc_low = SimpleNamespace(
id="doc-low", name="Document Low", data_source_type="upload_file", doc_metadata={"lang": "en"}
)
lookup_doc_high = SimpleNamespace(
id="doc-high", name="Document High", data_source_type="notion", doc_metadata={"lang": "fr"}
)
db_session = Mock()
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
tool = SingleDatasetRetrieverTool(
tenant_id="tenant-1",
dataset_id="dataset-1",
retrieve_config=_retrieve_config(),
return_resource=True,
retriever_from="dev",
hit_callbacks=[callback],
inputs={},
top_k=2,
)
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
with patch.object(single_retriever_module.RetrievalService, "retrieve", return_value=documents):
with patch.object(
single_retriever_module.RetrievalService,
"format_retrieval_documents",
return_value=records,
):
result = tool.run(query="hello")
assert result == "signed high\nsummary low\nquestion:signed low answer:low answer"
assert callback.documents == documents
assert callback.resources is not None
resource_info = callback.resources
assert [item.position for item in resource_info] == [1, 2]
assert resource_info[0].segment_id == "seg-high"
assert resource_info[0].hit_count == 9
assert resource_info[1].summary == "summary low"
assert resource_info[1].content == "question:raw low \nanswer:low answer"
def test_multi_dataset_retriever_from_dataset_sets_tool_name():
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=["dataset-1"],
tenant_id="tenant-1",
reranking_provider_name="provider",
reranking_model_name="model",
return_resource=False,
retriever_from="prod",
)
assert tool.name == "dataset_tenant_1"
def test_multi_dataset_retriever_retriever_returns_early_when_dataset_is_missing():
callback = _TestHitCallback()
all_documents: list[RagDocument] = []
db_session = Mock()
db_session.scalar.return_value = None
tool = DatasetMultiRetrieverTool(
tenant_id="tenant-1",
dataset_ids=["dataset-1"],
reranking_provider_name="provider",
reranking_model_name="model",
return_resource=False,
retriever_from="prod",
)
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
with patch.object(multi_retriever_module.RetrievalService, "retrieve") as retrieve_mock:
result = tool._retriever(
flask_app=_FakeFlaskApp(),
dataset_id="dataset-1",
query="hello",
all_documents=all_documents,
hit_callbacks=[callback],
)
assert result == []
assert all_documents == []
assert callback.queries == []
retrieve_mock.assert_not_called()
def test_multi_dataset_retriever_retriever_non_economy_uses_retrieval_model():
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="high_quality",
retrieval_model={
"search_method": "semantic_search",
"top_k": 6,
"score_threshold_enabled": True,
"score_threshold": 0.4,
"reranking_enable": False,
"reranking_mode": None,
"weights": {"balanced": True},
},
)
callback = _TestHitCallback()
documents = [RagDocument(page_content="retrieved", metadata={"doc_id": "node-1", "score": 0.4})]
all_documents: list[RagDocument] = []
db_session = Mock()
db_session.scalar.return_value = dataset
tool = DatasetMultiRetrieverTool(
tenant_id="tenant-1",
dataset_ids=["dataset-1"],
reranking_provider_name="provider",
reranking_model_name="model",
return_resource=False,
retriever_from="prod",
top_k=2,
)
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
with patch.object(multi_retriever_module.RetrievalService, "retrieve", return_value=documents) as retrieve_mock:
tool._retriever(
flask_app=_FakeFlaskApp(),
dataset_id="dataset-1",
query="hello",
all_documents=all_documents,
hit_callbacks=[callback],
)
assert all_documents == documents
assert callback.queries == [("hello", "dataset-1")]
retrieve_mock.assert_called_once_with(
retrieval_method="semantic_search",
dataset_id="dataset-1",
query="hello",
top_k=6,
score_threshold=0.4,
reranking_model=None,
reranking_mode="reranking_model",
weights={"balanced": True},
)
def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
callback = _TestHitCallback()
tool = DatasetMultiRetrieverTool(
tenant_id="tenant-1",
dataset_ids=["dataset-1", "dataset-2"],
reranking_provider_name="provider",
reranking_model_name="model",
return_resource=True,
retriever_from="dev",
hit_callbacks=[callback],
top_k=2,
score_threshold=0.1,
)
first_doc = RagDocument(page_content="first", metadata={"doc_id": "node-2", "score": 0.4})
second_doc = RagDocument(page_content="second", metadata={"doc_id": "node-1", "score": 0.9})
def fake_retriever(**kwargs):
if kwargs["dataset_id"] == "dataset-1":
kwargs["all_documents"].append(first_doc)
else:
kwargs["all_documents"].append(second_doc)
segment_for_node_2 = SimpleNamespace(
id="seg-2",
dataset_id="dataset-1",
document_id="doc-2",
index_node_id="node-2",
content="raw two",
answer="answer two",
hit_count=2,
word_count=20,
position=2,
index_node_hash="hash-2",
get_sign_content=lambda: "signed two",
)
segment_for_node_1 = SimpleNamespace(
id="seg-1",
dataset_id="dataset-2",
document_id="doc-1",
index_node_id="node-1",
content="raw one",
answer=None,
hit_count=7,
word_count=30,
position=1,
index_node_hash="hash-1",
get_sign_content=lambda: "signed one",
)
db_session = Mock()
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
db_session.query.return_value.filter_by.return_value.first.side_effect = [
SimpleNamespace(id="dataset-2", name="Dataset Two"),
SimpleNamespace(id="dataset-1", name="Dataset One"),
]
db_session.scalar.side_effect = [
SimpleNamespace(id="doc-1", name="Doc One", data_source_type="upload_file", doc_metadata={"p": 1}),
SimpleNamespace(id="doc-2", name="Doc Two", data_source_type="notion", doc_metadata={"p": 2}),
]
model_manager = Mock()
model_manager.get_model_instance.return_value = Mock()
rerank_runner = Mock()
rerank_runner.run.return_value = [second_doc, first_doc]
fake_current_app = SimpleNamespace(_get_current_object=lambda: _FakeFlaskApp())
with patch.object(tool, "_retriever", side_effect=fake_retriever) as retriever_mock:
with patch.object(multi_retriever_module, "current_app", fake_current_app):
with patch.object(multi_retriever_module.threading, "Thread", _ImmediateThread):
with patch.object(multi_retriever_module, "ModelManager", return_value=model_manager):
with patch.object(multi_retriever_module, "RerankModelRunner", return_value=rerank_runner):
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
result = tool.run(query="hello")
assert result == "signed one\nquestion:signed two answer:answer two"
assert retriever_mock.call_count == 2
assert callback.documents == [second_doc, first_doc]
assert callback.resources is not None
resource_info = callback.resources
assert [item.position for item in resource_info] == [1, 2]
assert resource_info[0].score == 0.9
assert resource_info[0].content == "raw one"
assert resource_info[1].score == 0.4
assert resource_info[1].content == "question:raw two \nanswer:answer two"

View File

@ -0,0 +1,158 @@
"""Unit tests for ModelInvocationUtils.
Covers success and error branches for ModelInvocationUtils, including
InvokeModelError and invoke error mappings for InvokeAuthorizationError,
InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, and
InvokeServerUnavailableError. Assumes mocked model instances and managers.
"""
from __future__ import annotations
from decimal import Decimal
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace:
model_type_instance = Mock()
model_type_instance.get_model_schema.return_value = (
SimpleNamespace(model_properties=schema or {}) if schema is not None else None
)
return SimpleNamespace(
provider="provider",
model="model-a",
model_name="model-a",
credentials={"api_key": "x"},
model_type_instance=model_type_instance,
get_llm_num_tokens=lambda prompt_messages: 5,
invoke_llm=Mock(),
)
@pytest.mark.parametrize(
("model_instance", "expected", "error_match"),
[
(None, None, "Model not found"),
(_mock_model_instance(schema=None), None, "No model schema found"),
(_mock_model_instance(schema={}), 2048, None),
(_mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 8192}), 8192, None),
],
ids=[
"missing-model",
"missing-schema",
"default-context-size",
"schema-context-size",
],
)
def test_get_max_llm_context_tokens_branches(model_instance, expected, error_match):
manager = Mock()
manager.get_default_model_instance.return_value = model_instance
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
if error_match:
with pytest.raises(InvokeModelError, match=error_match):
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
else:
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
def test_calculate_tokens_handles_missing_model():
manager = Mock()
manager.get_default_model_instance.return_value = None
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with pytest.raises(InvokeModelError, match="Model not found"):
ModelInvocationUtils.calculate_tokens("tenant", [])
def test_invoke_success_and_error_mappings():
model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048})
model_instance.invoke_llm.return_value = SimpleNamespace(
message=SimpleNamespace(content="ok"),
usage=SimpleNamespace(
completion_tokens=7,
completion_unit_price=Decimal("0.1"),
completion_price_unit=Decimal(1),
latency=0.3,
total_price=Decimal("0.7"),
currency="USD",
),
)
manager = Mock()
manager.get_default_model_instance.return_value = model_instance
class _ToolModelInvoke:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
response = ModelInvocationUtils.invoke(
user_id="u1",
tenant_id="tenant",
tool_type="builtin",
tool_name="tool-a",
prompt_messages=[],
)
assert response.message.content == "ok"
assert db_mock.session.add.call_count == 1
assert db_mock.session.commit.call_count == 2
@pytest.mark.parametrize(
("exc", "expected"),
[
(InvokeRateLimitError("rate"), "Invoke rate limit error"),
(InvokeBadRequestError("bad"), "Invoke bad request error"),
(InvokeConnectionError("conn"), "Invoke connection error"),
(InvokeAuthorizationError("auth"), "Invoke authorization error"),
(InvokeServerUnavailableError("down"), "Invoke server unavailable error"),
(RuntimeError("oops"), "Invoke error"),
],
ids=[
"rate-limit",
"bad-request",
"connection",
"authorization",
"server-unavailable",
"generic-error",
],
)
def test_invoke_error_mappings(exc, expected):
model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048})
model_instance.invoke_llm.side_effect = exc
manager = Mock()
manager.get_default_model_instance.return_value = model_instance
class _ToolModelInvoke:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
with pytest.raises(InvokeModelError, match=expected):
ModelInvocationUtils.invoke(
user_id="u1",
tenant_id="tenant",
tool_type="builtin",
tool_name="tool-a",
prompt_messages=[],
)

View File

@ -1,6 +1,12 @@
from json.decoder import JSONDecodeError
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from yaml import YAMLError
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
from core.tools.utils.parser import ApiBasedToolSchemaParser
@ -189,3 +195,225 @@ def test_parse_openapi_to_tool_bundle_default_value_type_casting(app):
available_param = params_by_name["available"]
assert available_param.type == "boolean"
assert available_param.default is True
def test_sanitize_default_value_and_type_detection():
assert ApiBasedToolSchemaParser._sanitize_default_value([]) is None
assert ApiBasedToolSchemaParser._sanitize_default_value({}) is None
assert ApiBasedToolSchemaParser._sanitize_default_value("ok") == "ok"
assert (
ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"}) == ToolParameter.ToolParameterType.FILE
)
assert (
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "integer"}) == ToolParameter.ToolParameterType.NUMBER
)
assert (
ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "boolean"}})
== ToolParameter.ToolParameterType.BOOLEAN
)
assert (
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}})
== ToolParameter.ToolParameterType.FILES
)
assert (
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}})
== ToolParameter.ToolParameterType.ARRAY
)
assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "object"}) is None
def test_parse_openapi_to_tool_bundle_server_env_and_refs(app):
openapi = {
"openapi": "3.0.0",
"info": {"title": "API", "version": "1.0.0", "description": "API description"},
"servers": [
{"url": "https://dev.example.com", "env": "dev"},
{"url": "https://prod.example.com", "env": "prod"},
],
"paths": {
"/items": {
"post": {
"description": "Create item",
"parameters": [
{"$ref": "#/components/parameters/token"},
{"name": "token", "schema": {"type": "string"}},
],
"requestBody": {
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/ItemRequest"}}}
},
}
}
},
"components": {
"parameters": {
"token": {"name": "token", "required": True, "schema": {"type": "string"}},
},
"schemas": {
"ItemRequest": {
"type": "object",
"required": ["age"],
"properties": {"age": {"type": "integer", "description": "Age", "default": 18}},
}
},
},
}
extra_info: dict = {}
warning: dict = {}
with app.test_request_context(headers={"X-Request-Env": "prod"}):
bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
assert len(bundles) == 1
assert bundles[0].server_url == "https://prod.example.com/items"
assert warning["duplicated_parameter"].startswith("Parameter token")
assert extra_info["description"] == "API description"
def test_parse_openapi_to_tool_bundle_no_server_raises(app):
openapi = {"info": {"title": "x"}, "servers": [], "paths": {}}
with app.test_request_context():
with pytest.raises(ToolProviderNotFoundError, match="No server found"):
ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
def test_parse_openapi_yaml_to_tool_bundle_invalid_yaml(app):
with app.test_request_context():
with pytest.raises(ToolApiSchemaError, match="Invalid openapi yaml"):
ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle("null")
def test_parse_swagger_to_openapi_branches():
with pytest.raises(ToolApiSchemaError, match="No server found"):
ApiBasedToolSchemaParser.parse_swagger_to_openapi({"info": {}, "paths": {}})
with pytest.raises(ToolApiSchemaError, match="No paths found"):
ApiBasedToolSchemaParser.parse_swagger_to_openapi({"servers": [{"url": "https://x"}], "paths": {}})
with pytest.raises(ToolApiSchemaError, match="No operationId found"):
ApiBasedToolSchemaParser.parse_swagger_to_openapi(
{
"servers": [{"url": "https://x"}],
"paths": {"/a": {"get": {"summary": "x", "responses": {}}}},
}
)
warning: dict = {"seed": True}
converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
{
"servers": [{"url": "https://x"}],
"paths": {"/a": {"get": {"operationId": "getA", "responses": {}}}},
"definitions": {"A": {"type": "object"}},
},
warning=warning,
)
assert converted["openapi"] == "3.0.0"
assert converted["components"]["schemas"]["A"]["type"] == "object"
assert warning["missing_summary"].startswith("No summary or description found")
def test_parse_openai_plugin_json_branches(app):
with app.test_request_context():
with pytest.raises(ToolProviderNotFoundError, match="Invalid openai plugin json"):
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle("{bad")
with pytest.raises(ToolNotSupportedError, match="Only openapi is supported"):
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
'{"api": {"url": "https://x", "type": "graphql"}}'
)
def test_parse_openai_plugin_json_http_branches(app):
with app.test_request_context():
response = type("Resp", (), {"status_code": 500, "text": "", "close": Mock()})()
with patch("core.tools.utils.parser.httpx.get", return_value=response):
with pytest.raises(ToolProviderNotFoundError, match="cannot get openapi yaml"):
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
'{"api": {"url": "https://x", "type": "openapi"}}'
)
response.close.assert_called_once()
success_response = type("Resp", (), {"status_code": 200, "text": "openapi: 3.0.0", "close": Mock()})()
with patch("core.tools.utils.parser.httpx.get", return_value=success_response):
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle",
return_value=["bundle"],
) as mock_parse:
bundles = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
'{"api": {"url": "https://x", "type": "openapi"}}'
)
assert bundles == ["bundle"]
mock_parse.assert_called_once()
success_response.close.assert_called_once()
def test_auto_parse_json_yaml_failure():
with patch("core.tools.utils.parser.json_loads", side_effect=JSONDecodeError("bad", "x", 0)):
with patch("core.tools.utils.parser.safe_load", side_effect=YAMLError("bad yaml")):
with pytest.raises(ToolApiSchemaError, match="Invalid api schema, schema is neither json nor yaml"):
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(":::")
def test_auto_parse_openapi_success():
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
return_value=["openapi-bundle"],
):
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
assert bundles == ["openapi-bundle"]
assert schema_type == ApiProviderSchemaType.OPENAPI
def test_auto_parse_openapi_then_swagger():
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
loaded_content = {
"openapi": "3.0.0",
"servers": [{"url": "https://x"}],
"info": {"title": "x"},
"paths": {},
}
converted_swagger = {
"openapi": "3.0.0",
"servers": [{"url": "https://x"}],
"info": {"title": "x"},
"paths": {},
}
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
side_effect=[ToolApiSchemaError("openapi error"), ["swagger-bundle"]],
) as mock_parse_openapi:
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi",
return_value=converted_swagger,
) as mock_parse_swagger:
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
assert bundles == ["swagger-bundle"]
assert schema_type == ApiProviderSchemaType.SWAGGER
mock_parse_swagger.assert_called_once_with(loaded_content, extra_info={}, warning={})
assert mock_parse_openapi.call_count == 2
mock_parse_openapi.assert_any_call(loaded_content, extra_info={}, warning={})
mock_parse_openapi.assert_any_call(converted_swagger, extra_info={}, warning={})
def test_auto_parse_openapi_swagger_then_plugin():
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
side_effect=ToolApiSchemaError("openapi error"),
):
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi",
side_effect=ToolApiSchemaError("swagger error"),
):
with patch(
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle",
return_value=["plugin-bundle"],
):
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
assert bundles == ["plugin-bundle"]
assert schema_type == ApiProviderSchemaType.OPENAI_PLUGIN

View File

@ -0,0 +1,51 @@
from __future__ import annotations
import pytest
from core.tools.utils import system_oauth_encryption as oauth_encryption
from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter
def test_system_oauth_encrypter_roundtrip():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}
encrypted = encrypter.encrypt_oauth_params(payload)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert encrypted
assert dict(decrypted) == payload
def test_system_oauth_encrypter_decrypt_validates_input():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
with pytest.raises(ValueError, match="must be a string"):
encrypter.decrypt_oauth_params(123) # type: ignore[arg-type]
with pytest.raises(ValueError, match="cannot be empty"):
encrypter.decrypt_oauth_params("")
def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
encrypter.decrypt_oauth_params("not-base64")
def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")
first = oauth_encryption.get_system_oauth_encrypter()
second = oauth_encryption.get_system_oauth_encrypter()
assert first is second
encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}
def test_create_system_oauth_encrypter_factory():
encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
assert isinstance(encrypter, SystemOAuthEncrypter)

View File

@ -1,7 +1,9 @@
import pytest
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
@ -31,3 +33,91 @@ def test_ensure_no_human_input_nodes_raises_for_human_input():
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
def test_get_workflow_graph_variables_and_outputs():
graph = {
"nodes": [
{
"id": "start",
"data": {
"type": "start",
"variables": [
{
"variable": "query",
"label": "Query",
"type": "text-input",
"required": True,
}
],
},
},
{
"id": "end-1",
"data": {
"type": "end",
"outputs": [
{"variable": "answer", "value_type": "string", "value_selector": ["n1", "answer"]},
{"variable": "score", "value_type": "number", "value_selector": ["n1", "score"]},
],
},
},
{
"id": "end-2",
"data": {
"type": "end",
"outputs": [
{"variable": "answer", "value_type": "object", "value_selector": ["n2", "answer"]},
],
},
},
]
}
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
assert len(variables) == 1
assert variables[0].variable == "query"
assert variables[0].type == VariableEntityType.TEXT_INPUT
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
assert [output.variable for output in outputs] == ["answer", "score"]
assert outputs[0].value_type == "object"
assert outputs[1].value_type == "number"
no_start = WorkflowToolConfigurationUtils.get_workflow_graph_variables({"nodes": []})
assert no_start == []
def test_check_is_synced_validation():
variables = [
VariableEntity(
variable="query",
label="Query",
type=VariableEntityType.TEXT_INPUT,
required=True,
)
]
configs = [
WorkflowToolParameterConfiguration(
name="query",
description="desc",
form=ToolParameter.ToolParameterForm.FORM,
)
]
WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=configs)
with pytest.raises(ValueError, match="parameter configuration mismatch"):
WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=[])
with pytest.raises(ValueError, match="parameter configuration mismatch"):
WorkflowToolConfigurationUtils.check_is_synced(
variables=variables,
tool_configurations=[
WorkflowToolParameterConfiguration(
name="other",
description="desc",
form=ToolParameter.ToolParameterForm.FORM,
)
],
)

View File

@ -0,0 +1,196 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
def _controller() -> WorkflowToolProviderController:
entity = ToolProviderEntity(
identity=ToolProviderIdentity(
author="author",
name="wf-provider",
description=I18nObject(en_US="desc"),
icon="icon.svg",
label=I18nObject(en_US="WF"),
),
credentials_schema=[],
)
return WorkflowToolProviderController(entity=entity, provider_id="provider-1")
def _mock_session_with_begin() -> Mock:
session = Mock()
begin_cm = Mock()
begin_cm.__enter__ = Mock(return_value=None)
begin_cm.__exit__ = Mock(return_value=False)
session.begin.return_value = begin_cm
return session
def test_get_db_provider_tool_builds_entity():
controller = _controller()
session = Mock()
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
session.query.return_value.where.return_value.first.return_value = workflow
app = SimpleNamespace(id="app-1")
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
user_id="user-1",
parameter_configurations=[
SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM),
SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM),
],
)
user = SimpleNamespace(name="Alice")
variables = [
VariableEntity(
variable="country",
label="Country",
description="Country",
type=VariableEntityType.SELECT,
required=True,
options=["US", "IN"],
)
]
outputs = [
SimpleNamespace(variable="json", value_type="string"),
SimpleNamespace(variable="answer", value_type="string"),
]
with (
patch(
"core.tools.workflow_as_tool.provider.WorkflowAppConfigManager.convert_features",
return_value=SimpleNamespace(file_upload=True),
),
patch(
"core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_variables",
return_value=variables,
),
patch(
"core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_output",
return_value=outputs,
),
):
tool = controller._get_db_provider_tool(db_provider, app, session=session, user=user)
assert tool.entity.identity.name == "workflow_tool"
# "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out.
assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}}
assert "json" not in tool.entity.output_schema["properties"]
assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT
assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES
assert controller.provider_type == ToolProviderType.WORKFLOW
def test_get_tool_returns_hit_or_none():
controller = _controller()
tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool")))
controller.tools = [tool]
assert controller.get_tool("workflow_tool") is tool
assert controller.get_tool("missing") is None
def test_get_tools_returns_cached():
controller = _controller()
cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))]
controller.tools = cached_tools # type: ignore[assignment]
assert controller.get_tools("tenant-1") == cached_tools
def test_from_db_builds_controller():
controller = _controller()
app = SimpleNamespace(id="app-1")
user = SimpleNamespace(name="Alice")
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
user_id="user-1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
parameter_configurations=[],
)
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.get.side_effect = [app, user]
fake_cm = MagicMock()
fake_cm.__enter__.return_value = session
fake_cm.__exit__.return_value = False
fake_session_factory = Mock()
fake_session_factory.create_session.return_value = fake_cm
with patch("core.tools.workflow_as_tool.provider.session_factory", fake_session_factory):
with patch.object(
WorkflowToolProviderController,
"_get_db_provider_tool",
return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))),
):
built = WorkflowToolProviderController.from_db(db_provider)
assert isinstance(built, WorkflowToolProviderController)
assert built.tools
def test_get_tools_returns_empty_when_provider_missing():
controller = _controller()
controller.tools = None # type: ignore[assignment]
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = None
session_cls.return_value.__enter__.return_value = session
assert controller.get_tools("tenant-1") == []
def test_get_tools_raises_when_app_missing():
controller = _controller()
controller.tools = None # type: ignore[assignment]
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
user_id="user-1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
parameter_configurations=[],
)
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.get.return_value = None
session_cls.return_value.__enter__.return_value = session
with pytest.raises(ValueError, match="app not found"):
controller.get_tools("tenant-1")

View File

@ -1,20 +1,85 @@
"""Unit tests for workflow-as-tool behavior.
StubSession/StubScalars emulate SQLAlchemy session/scalars with minimal methods
(`scalar`, `scalars`, `expunge`, `commit`, `refresh`, context manager) to keep
database access mocked and predictable in tests.
"""
import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
from dify_graph.file import FILE_MODEL_IDENTITY
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
`WorkflowAppGenerator.generate` returns a result with `error` key inside
the `data` element.
"""
class StubScalars:
"""Minimal stub for SQLAlchemy scalar results."""
_value: Any
def __init__(self, value: Any) -> None:
self._value = value
def first(self) -> Any:
return self._value
class StubSession:
"""Minimal stub for session_factory-created sessions."""
scalar_results: list[Any]
scalars_results: list[Any]
expunge_calls: list[object]
def __init__(self, *, scalar_results: list[Any] | None = None, scalars_results: list[Any] | None = None) -> None:
self.scalar_results = list(scalar_results or [])
self.scalars_results = list(scalars_results or [])
self.expunge_calls: list[object] = []
def scalar(self, _stmt: Any) -> Any:
return self.scalar_results.pop(0)
def scalars(self, _stmt: Any) -> StubScalars:
return StubScalars(self.scalars_results.pop(0))
def expunge(self, value: Any) -> None:
self.expunge_calls.append(value)
def begin(self) -> "StubSession":
return self
def commit(self) -> None:
pass
def refresh(self, _value: Any) -> None:
pass
def close(self) -> None:
pass
def __enter__(self) -> "StubSession":
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
def _build_tool() -> WorkflowTool:
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
@ -22,9 +87,9 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
return WorkflowTool(
workflow_app_id="app-1",
workflow_as_tool_id="wf-tool-1",
version="1",
workflow_entities={},
workflow_call_depth=1,
@ -32,13 +97,19 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
runtime=runtime,
)
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
`WorkflowAppGenerator.generate` returns a result with `error` key inside
the `data` element.
"""
tool = _build_tool()
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
@ -56,28 +127,12 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
"""Ensure pause_state_config is passed as None."""
tool = _build_tool()
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
from unittest.mock import MagicMock, Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
@ -94,22 +149,7 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should generate variable messages when there are outputs"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
tool = _build_tool()
# Mock workflow outputs
mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}}
@ -119,8 +159,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
@ -134,10 +172,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
# Execute tool invocation
messages = list(tool.invoke("test_user", {}))
# Verify generated messages
# Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages
assert len(messages) == 5
# Verify variable messages
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
assert len(variable_messages) == 3
@ -151,7 +185,7 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
# Verify text message
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
assert len(text_messages) == 1
assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text
assert json.loads(text_messages[0].message.text) == mock_outputs
# Verify JSON message
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
@ -161,30 +195,13 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should handle empty outputs correctly"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
tool = _build_tool()
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
@ -217,61 +234,32 @@ def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPat
assert json_messages[0].message.json_object == {}
def test_create_variable_message():
"""Test the functionality of creating variable messages"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# Test different types of variable values
test_cases = [
@pytest.mark.parametrize(
("var_name", "var_value"),
[
("string_var", "test string"),
("int_var", 42),
("float_var", 3.14),
("bool_var", True),
("list_var", [1, 2, 3]),
("dict_var", {"key": "value"}),
]
],
)
def test_create_variable_message(var_name, var_value):
"""Create variable messages for multiple value types."""
tool = _build_tool()
for var_name, var_value in test_cases:
message = tool.create_variable_message(var_name, var_value)
message = tool.create_variable_message(var_name, var_value)
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
assert message.message.variable_name == var_name
assert message.message.variable_value == var_value
assert message.message.stream is False
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
assert message.message.variable_name == var_name
assert message.message.variable_value == var_value
assert message.message.stream is False
def test_create_file_message_should_include_file_marker():
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
"""Ensure file message includes marker and meta payload."""
tool = _build_tool()
file_obj = object()
message = tool.create_file_message(file_obj) # type: ignore[arg-type]
@ -284,103 +272,247 @@ def test_create_file_message_should_include_file_marker():
def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch):
"""Ensure worker context can resolve EndUser when Account is missing."""
class StubSession:
def __init__(self, results: list):
self.results = results
def scalar(self, _stmt):
return self.results.pop(0)
# SQLAlchemy Session APIs used by code under test
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
# support `with session_factory.create_session() as session:`
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
# Monkeypatch session factory to return our stub session
stub_session = StubSession(scalar_results=[tenant, None, end_user])
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([tenant, None, end_user]),
lambda: stub_session,
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
tool = _build_tool()
tool.runtime.invoke_from = InvokeFrom.SERVICE_API
tool.runtime.tenant_id = "tenant_id"
resolved_user = tool._resolve_user_from_database(user_id=end_user.id)
assert resolved_user is end_user
assert stub_session.expunge_calls == [end_user]
def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch):
"""Return None if tenant cannot be found in worker context."""
class StubSession:
def __init__(self, results: list):
self.results = results
def scalar(self, _stmt):
return self.results.pop(0)
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
# Monkeypatch session factory to return our stub session with no tenant
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([None]),
lambda: StubSession(scalar_results=[None]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
tool = _build_tool()
tool.runtime.invoke_from = InvokeFrom.SERVICE_API
tool.runtime.tenant_id = "missing_tenant"
resolved_user = tool._resolve_user_from_database(user_id="any")
assert resolved_user is None
def test_workflow_tool_provider_type_and_fork_runtime():
"""Verify provider type and forked runtime behavior."""
tool = _build_tool()
assert tool.tool_provider_type() == ToolProviderType.WORKFLOW
assert tool.latest_usage.total_tokens == 0
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2", invoke_from=InvokeFrom.DEBUGGER))
assert isinstance(forked, WorkflowTool)
assert forked.workflow_app_id == tool.workflow_app_id
assert forked.runtime.tenant_id == "tenant-2"
def test_derive_usage_from_top_level_usage_key():
"""Derive usage from top-level usage dict."""
usage = WorkflowTool._derive_usage_from_result({"usage": {"total_tokens": 12, "total_price": "0.2"}})
assert usage.total_tokens == 12
def test_derive_usage_from_metadata_usage():
"""Derive usage from metadata usage dict."""
metadata_usage = WorkflowTool._derive_usage_from_result({"metadata": {"usage": {"total_tokens": 7}}})
assert metadata_usage.total_tokens == 7
def test_derive_usage_from_totals():
"""Derive usage from top-level totals fields."""
totals_usage = WorkflowTool._derive_usage_from_result(
{"total_tokens": "9", "total_price": "1.3", "currency": "USD"}
)
assert totals_usage.total_tokens == 9
assert str(totals_usage.total_price) == "1.3"
def test_derive_usage_from_empty():
"""Default usage values when result is empty."""
empty_usage = WorkflowTool._derive_usage_from_result({})
assert empty_usage.total_tokens == 0
def test_extract_usage_from_nested():
"""Extract nested usage dict from result payloads."""
nested = WorkflowTool._extract_usage_dict({"nested": [{"data": {"usage": {"total_tokens": 3}}}]})
assert nested == {"total_tokens": 3}
def test_invoke_raises_when_user_not_found(monkeypatch: pytest.MonkeyPatch):
"""Raise ToolInvokeError when user resolution fails."""
tool = _build_tool()
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: None)
with pytest.raises(ToolInvokeError, match="User not found"):
list(tool.invoke("missing", {}))
def test_resolve_user_from_database_returns_account(monkeypatch: pytest.MonkeyPatch):
"""Resolve Account and set tenant in worker context."""
tenant = SimpleNamespace(id="tenant_id")
account = SimpleNamespace(id="account_id", current_tenant=None)
session = StubSession(scalar_results=[tenant, account])
monkeypatch.setattr("core.tools.workflow_as_tool.tool.session_factory.create_session", lambda: session)
tool = _build_tool()
tool.runtime.tenant_id = "tenant_id"
resolved = tool._resolve_user_from_database(user_id="account_id")
assert resolved is account
assert account.current_tenant is tenant
assert session.expunge_calls == [account]
def test_get_workflow_and_get_app_db_branches(monkeypatch: pytest.MonkeyPatch):
"""Cover workflow/app retrieval branches and error cases."""
tool = _build_tool()
latest_workflow = SimpleNamespace(id="wf-latest")
specific_workflow = SimpleNamespace(id="wf-v1")
app = SimpleNamespace(id="app-1")
sessions = iter(
[
StubSession(scalar_results=[], scalars_results=[latest_workflow]),
StubSession(scalar_results=[specific_workflow], scalars_results=[]),
StubSession(scalar_results=[app], scalars_results=[]),
]
)
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: next(sessions),
)
assert tool._get_workflow("app-1", "") is latest_workflow
assert tool._get_workflow("app-1", "1") is specific_workflow
assert tool._get_app("app-1") is app
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession(scalar_results=[None, None], scalars_results=[None]),
)
with pytest.raises(ValueError, match="workflow not found"):
tool._get_workflow("app-1", "1")
with pytest.raises(ValueError, match="app not found"):
tool._get_app("app-1")
def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool:
"""Build a WorkflowTool and stub merged runtime parameters for files/query."""
tool = _build_tool()
files_param = ToolParameter.get_simple_instance(
name="files",
llm_description="files",
typ=ToolParameter.ToolParameterType.SYSTEM_FILES,
required=False,
)
files_param.form = ToolParameter.ToolParameterForm.FORM
text_param = ToolParameter.get_simple_instance(
name="query",
llm_description="query",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
)
text_param.form = ToolParameter.ToolParameterForm.FORM
monkeypatch.setattr(tool, "get_merged_runtime_parameters", lambda: [files_param, text_param])
return tool
def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
"""Transform args into parameters and files payloads."""
tool = _setup_transform_args_tool(monkeypatch)
params, files = tool._transform_args(
{
"query": "hello",
"files": [
{
"tenant_id": "tenant-1",
"type": "image",
"transfer_method": "tool_file",
"related_id": "tool-1",
"extension": ".png",
},
{
"tenant_id": "tenant-1",
"type": "document",
"transfer_method": "local_file",
"related_id": "upload-1",
},
{
"tenant_id": "tenant-1",
"type": "document",
"transfer_method": "remote_url",
"remote_url": "https://example.com/a.pdf",
},
],
}
)
assert params == {"query": "hello"}
assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files)
assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files)
assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files)
def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch):
"""Ignore invalid file entries while keeping params."""
tool = _setup_transform_args_tool(monkeypatch)
invalid_params, invalid_files = tool._transform_args({"query": "hello", "files": [{"invalid": True}]})
assert invalid_params == {"query": "hello"}
assert invalid_files == []
def test_extract_files():
"""Extract file outputs into result and file list."""
tool = _build_tool()
built_files = [
SimpleNamespace(id="file-1"),
SimpleNamespace(id="file-2"),
]
with patch("core.tools.workflow_as_tool.tool.build_from_mapping", side_effect=built_files):
outputs = {
"attachments": [
{
"dify_model_identity": FILE_MODEL_IDENTITY,
"transfer_method": "tool_file",
"related_id": "r1",
}
],
"single_file": {
"dify_model_identity": FILE_MODEL_IDENTITY,
"transfer_method": "local_file",
"related_id": "r2",
},
"text": "ok",
}
result, extracted_files = tool._extract_files(outputs)
assert result["text"] == "ok"
assert len(extracted_files) == 2
def test_update_file_mapping():
"""Map tool/local file transfer methods into output shape."""
tool = _build_tool()
tool_file = tool._update_file_mapping({"transfer_method": "tool_file", "related_id": "tool-1"})
assert tool_file["tool_file_id"] == "tool-1"
local_file = tool._update_file_mapping({"transfer_method": "local_file", "related_id": "upload-1"})
assert local_file["upload_file_id"] == "upload-1"

View File