mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
test: Unit test cases for core.tools module (#32447)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: wangxiaolei <fatelei@gmail.com> Co-authored-by: akashseth-ifp <akash.seth@infocusp.com> Co-authored-by: mahammadasim <135003320+mahammadasim@users.noreply.github.com>
This commit is contained in:
parent
e99628b76f
commit
b170eabaf3
@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
:param credential_type: the type of the credential, as CredentialType or str; str values
|
||||
are normalized via CredentialType.of and may raise ValueError for invalid values.
|
||||
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
|
||||
empty list for CredentialType.UNAUTHORIZED or missing schemas.
|
||||
|
||||
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
|
||||
Raises ValueError for invalid credential types.
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
|
||||
@ -137,6 +137,7 @@ class ToolFileManager:
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
0
api/tests/unit_tests/core/tools/__init__.py
Normal file
0
api/tests/unit_tests/core/tools/__init__.py
Normal file
103
api/tests/unit_tests/core/tools/test_builtin_tool_base.py
Normal file
103
api/tests/unit_tests/core/tools/test_builtin_tool_base.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
|
||||
class _BuiltinDummyTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
def _build_tool() -> _BuiltinDummyTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
|
||||
parameters=[],
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
|
||||
|
||||
|
||||
def test_builtin_tool_fork_and_provider_type():
|
||||
tool = _build_tool()
|
||||
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
|
||||
assert isinstance(forked, _BuiltinDummyTool)
|
||||
assert forked.runtime.tenant_id == "tenant-2"
|
||||
assert tool.tool_provider_type() == ToolProviderType.BUILT_IN
|
||||
|
||||
|
||||
def test_invoke_model_calls_model_invocation_utils_invoke():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
|
||||
assert (
|
||||
tool.invoke_model(
|
||||
user_id="u1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
stop=[],
|
||||
)
|
||||
== "result"
|
||||
)
|
||||
mock_invoke.assert_called_once()
|
||||
|
||||
|
||||
def test_get_max_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
|
||||
assert tool.get_max_tokens() == 4096
|
||||
|
||||
|
||||
def test_get_prompt_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
|
||||
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
|
||||
|
||||
|
||||
def test_runtime_none_raises():
|
||||
tool = _build_tool()
|
||||
tool.runtime = None
|
||||
with pytest.raises(ValueError, match="runtime is required"):
|
||||
tool.get_max_tokens()
|
||||
with pytest.raises(ValueError, match="runtime is required"):
|
||||
tool.get_prompt_tokens([UserPromptMessage(content="hello")])
|
||||
|
||||
|
||||
def test_builtin_tool_summary_short_and_long_content_paths():
|
||||
tool = _build_tool()
|
||||
|
||||
with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=100):
|
||||
with patch.object(_BuiltinDummyTool, "get_prompt_tokens", return_value=10):
|
||||
assert tool.summary(user_id="u1", content="short") == "short"
|
||||
|
||||
with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=10):
|
||||
with patch.object(
|
||||
_BuiltinDummyTool,
|
||||
"get_prompt_tokens",
|
||||
side_effect=lambda prompt_messages: len(prompt_messages[-1].content),
|
||||
):
|
||||
with patch.object(
|
||||
_BuiltinDummyTool,
|
||||
"invoke_model",
|
||||
return_value=SimpleNamespace(message=SimpleNamespace(content="S")),
|
||||
):
|
||||
result = tool.summary(user_id="u1", content="x" * 30 + "\n" + "y" * 5)
|
||||
|
||||
assert result
|
||||
assert "S" in result
|
||||
216
api/tests/unit_tests/core/tools/test_builtin_tool_provider.py
Normal file
216
api/tests/unit_tests/core/tools/test_builtin_tool_provider.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
|
||||
|
||||
class _FakeBuiltinTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
class _ConcreteBuiltinProvider(BuiltinToolProviderController):
|
||||
last_validation: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
self.last_validation = (user_id, credentials)
|
||||
|
||||
|
||||
def _provider_yaml() -> dict[str, Any]:
|
||||
return {
|
||||
"identity": {
|
||||
"author": "Dify",
|
||||
"name": "fake_provider",
|
||||
"label": {"en_US": "Fake Provider"},
|
||||
"description": {"en_US": "Fake description"},
|
||||
"icon": "icon.svg",
|
||||
"tags": ["utilities"],
|
||||
},
|
||||
"credentials_for_provider": {
|
||||
"api_key": {
|
||||
"type": "secret-input",
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
"oauth_schema": {
|
||||
"client_schema": [
|
||||
{
|
||||
"name": "client_id",
|
||||
"type": "text-input",
|
||||
}
|
||||
],
|
||||
"credentials_schema": [
|
||||
{
|
||||
"name": "access_token",
|
||||
"type": "secret-input",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _tool_yaml() -> dict[str, Any]:
|
||||
return {
|
||||
"identity": {
|
||||
"author": "Dify",
|
||||
"name": "tool_a",
|
||||
"label": {"en_US": "Tool A"},
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
||||
|
||||
|
||||
def test_builtin_tool_provider_init_load_tools_and_basic_accessors(monkeypatch):
|
||||
yaml_payloads = [_provider_yaml(), _tool_yaml()]
|
||||
|
||||
def _load_yaml(*args, **kwargs):
|
||||
return yaml_payloads.pop(0)
|
||||
|
||||
monkeypatch.setattr("core.tools.builtin_tool.provider.load_yaml_file_cached", _load_yaml)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.provider.listdir",
|
||||
lambda *args, **kwargs: ["tool_a.yaml", "__init__.py", "readme.md"],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
|
||||
lambda *args, **kwargs: _FakeBuiltinTool,
|
||||
)
|
||||
provider = _ConcreteBuiltinProvider()
|
||||
|
||||
assert provider.get_credentials_schema()
|
||||
assert provider.get_tools()
|
||||
assert provider.get_tool("tool_a") is not None
|
||||
assert provider.get_tool("missing") is None
|
||||
assert provider.provider_type == ToolProviderType.BUILT_IN
|
||||
assert provider.tool_labels == ["utilities"]
|
||||
assert provider.need_credentials is True
|
||||
|
||||
oauth_schema = provider.get_credentials_schema_by_type(CredentialType.OAUTH2)
|
||||
assert len(oauth_schema) == 1
|
||||
api_schema = provider.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
assert len(api_schema) == 1
|
||||
assert provider.get_oauth_client_schema()[0].name == "client_id"
|
||||
assert set(provider.get_supported_credential_types()) == {CredentialType.API_KEY, CredentialType.OAUTH2}
|
||||
|
||||
|
||||
def test_builtin_tool_provider_invalid_credential_type_raises():
|
||||
with (
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_yaml_file_cached",
|
||||
side_effect=[_provider_yaml(), _tool_yaml()],
|
||||
),
|
||||
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
|
||||
return_value=_FakeBuiltinTool,
|
||||
),
|
||||
):
|
||||
provider = _ConcreteBuiltinProvider()
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid credential type: invalid"):
|
||||
provider.get_credentials_schema_by_type("invalid")
|
||||
|
||||
|
||||
def test_builtin_tool_provider_validate_credentials_delegates():
|
||||
with (
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_yaml_file_cached",
|
||||
side_effect=[_provider_yaml(), _tool_yaml()],
|
||||
),
|
||||
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
|
||||
return_value=_FakeBuiltinTool,
|
||||
),
|
||||
):
|
||||
provider = _ConcreteBuiltinProvider()
|
||||
|
||||
provider.validate_credentials("user-1", {"api_key": "secret"})
|
||||
assert provider.last_validation == ("user-1", {"api_key": "secret"})
|
||||
|
||||
|
||||
def test_builtin_tool_provider_unauthorized_schema_is_empty():
|
||||
with (
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_yaml_file_cached",
|
||||
side_effect=[_provider_yaml(), _tool_yaml()],
|
||||
),
|
||||
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
|
||||
return_value=_FakeBuiltinTool,
|
||||
),
|
||||
):
|
||||
provider = _ConcreteBuiltinProvider()
|
||||
|
||||
assert provider.get_credentials_schema_by_type(CredentialType.UNAUTHORIZED) == []
|
||||
|
||||
|
||||
def test_builtin_tool_provider_init_raises_when_provider_yaml_missing():
|
||||
with patch("core.tools.builtin_tool.provider.load_yaml_file_cached", side_effect=RuntimeError("boom")):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="can not load provider yaml"):
|
||||
_ConcreteBuiltinProvider()
|
||||
|
||||
|
||||
def test_builtin_tool_provider_handles_empty_credentials_and_oauth():
|
||||
provider = object.__new__(_ConcreteBuiltinProvider)
|
||||
provider.tools = []
|
||||
provider.entity = ToolProviderEntity.model_validate(
|
||||
{
|
||||
"identity": {
|
||||
"author": "Dify",
|
||||
"name": "fake_provider",
|
||||
"label": {"en_US": "Fake Provider"},
|
||||
"description": {"en_US": "Fake description"},
|
||||
"icon": "icon.svg",
|
||||
"tags": None,
|
||||
},
|
||||
"credentials_schema": [],
|
||||
"oauth_schema": None,
|
||||
},
|
||||
)
|
||||
|
||||
assert provider.get_oauth_client_schema() == []
|
||||
assert provider.get_supported_credential_types() == []
|
||||
assert provider.need_credentials is False
|
||||
assert provider._get_tool_labels() == []
|
||||
|
||||
|
||||
def test_builtin_tool_provider_forked_tool_runtime_is_initialized():
|
||||
with (
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_yaml_file_cached",
|
||||
side_effect=[_provider_yaml(), _tool_yaml()],
|
||||
),
|
||||
patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]),
|
||||
patch(
|
||||
"core.tools.builtin_tool.provider.load_single_subclass_from_source",
|
||||
return_value=_FakeBuiltinTool,
|
||||
),
|
||||
):
|
||||
provider = _ConcreteBuiltinProvider()
|
||||
|
||||
tool = provider.get_tool("tool_a")
|
||||
assert tool is not None
|
||||
assert isinstance(tool.runtime, ToolRuntime)
|
||||
assert tool.runtime.tenant_id == ""
|
||||
tool.runtime.invoke_from = InvokeFrom.DEBUGGER
|
||||
assert tool.runtime.invoke_from == InvokeFrom.DEBUGGER
|
||||
310
api/tests/unit_tests/core/tools/test_builtin_tools_extra.py
Normal file
310
api/tests/unit_tests/core/tools/test_builtin_tools_extra.py
Normal file
@ -0,0 +1,310 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.providers.audio.audio import AudioToolProvider
|
||||
from core.tools.builtin_tool.providers.audio.tools.asr import ASRTool
|
||||
from core.tools.builtin_tool.providers.audio.tools.tts import TTSTool
|
||||
from core.tools.builtin_tool.providers.code.code import CodeToolProvider
|
||||
from core.tools.builtin_tool.providers.code.tools.simple_code import SimpleCode
|
||||
from core.tools.builtin_tool.providers.time.time import WikiPediaProvider
|
||||
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.builtin_tool.providers.time.tools.localtime_to_timestamp import LocaltimeToTimestampTool
|
||||
from core.tools.builtin_tool.providers.time.tools.timestamp_to_localtime import TimestampToLocaltimeTool
|
||||
from core.tools.builtin_tool.providers.time.tools.timezone_conversion import TimezoneConversionTool
|
||||
from core.tools.builtin_tool.providers.time.tools.weekday import WeekdayTool
|
||||
from core.tools.builtin_tool.providers.webscraper.tools.webscraper import WebscraperTool
|
||||
from core.tools.builtin_tool.providers.webscraper.webscraper import WebscraperProvider
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from dify_graph.file.enums import FileType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
|
||||
|
||||
def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="tool-a",
|
||||
label=I18nObject(en_US="tool-a"),
|
||||
provider="provider-a",
|
||||
),
|
||||
parameters=[],
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
return tool_cls(provider="provider-a", entity=entity, runtime=runtime)
|
||||
|
||||
|
||||
def _raise_runtime_error(*_args: object, **_kwargs: object) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_current_time_tool():
|
||||
current_tool = _build_builtin_tool(CurrentTimeTool)
|
||||
utc_text = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "UTC"}))[0].message.text
|
||||
assert utc_text
|
||||
|
||||
invalid_tz = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "Invalid/TZ"}))[0].message.text
|
||||
assert "Invalid timezone" in invalid_tz
|
||||
|
||||
|
||||
def test_localtime_to_timestamp_tool():
|
||||
localtime_tool = _build_builtin_tool(LocaltimeToTimestampTool)
|
||||
ts_message = list(
|
||||
localtime_tool.invoke(user_id="u", tool_parameters={"localtime": "2024-01-01 10:00:00", "timezone": "UTC"})
|
||||
)[0].message.text
|
||||
ts_value = float(ts_message.strip())
|
||||
assert math.isfinite(ts_value)
|
||||
assert ts_value >= 0
|
||||
with pytest.raises(ToolInvokeError):
|
||||
LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC")
|
||||
|
||||
|
||||
def test_timestamp_to_localtime_tool():
|
||||
to_local_tool = _build_builtin_tool(TimestampToLocaltimeTool)
|
||||
local_text = list(to_local_tool.invoke(user_id="u", tool_parameters={"timestamp": 1704067200, "timezone": "UTC"}))[
|
||||
0
|
||||
].message.text
|
||||
assert "2024" in local_text
|
||||
with pytest.raises(ToolInvokeError):
|
||||
TimestampToLocaltimeTool.timestamp_to_localtime("bad", "UTC") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_timezone_conversion_tool():
|
||||
timezone_tool = _build_builtin_tool(TimezoneConversionTool)
|
||||
converted = list(
|
||||
timezone_tool.invoke(
|
||||
user_id="u",
|
||||
tool_parameters={
|
||||
"current_time": "2024-01-01 08:00:00",
|
||||
"current_timezone": "UTC",
|
||||
"target_timezone": "Asia/Tokyo",
|
||||
},
|
||||
)
|
||||
)[0].message.text
|
||||
assert converted.startswith("2024-01-01")
|
||||
with pytest.raises(ToolInvokeError):
|
||||
TimezoneConversionTool.timezone_convert("bad", "UTC", "Asia/Tokyo")
|
||||
|
||||
|
||||
def test_weekday_tool():
|
||||
weekday_tool = _build_builtin_tool(WeekdayTool)
|
||||
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
|
||||
assert "January 1, 2024" in valid
|
||||
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
|
||||
0
|
||||
].message.text
|
||||
assert "Invalid date" in invalid
|
||||
with pytest.raises(ValueError, match="Month is required"):
|
||||
list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "day": 1}))
|
||||
|
||||
|
||||
def test_simple_code_valid_execution(monkeypatch):
|
||||
simple_code = _build_builtin_tool(SimpleCode)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code",
|
||||
lambda *a: "ok",
|
||||
)
|
||||
result = list(
|
||||
simple_code.invoke(
|
||||
user_id="u",
|
||||
tool_parameters={"language": "python3", "code": "print(1)"},
|
||||
)
|
||||
)[0].message.text
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
def test_simple_code_invalid_language():
|
||||
simple_code = _build_builtin_tool(SimpleCode)
|
||||
|
||||
with pytest.raises(ValueError, match="Only python3 and javascript"):
|
||||
list(simple_code.invoke(user_id="u", tool_parameters={"language": "go", "code": "fmt.Println(1)"}))
|
||||
|
||||
|
||||
def test_simple_code_execution_error(monkeypatch):
|
||||
simple_code = _build_builtin_tool(SimpleCode)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code",
|
||||
_raise_runtime_error,
|
||||
)
|
||||
with pytest.raises(ToolInvokeError, match="boom"):
|
||||
list(simple_code.invoke(user_id="u", tool_parameters={"language": "python3", "code": "print(1)"}))
|
||||
|
||||
|
||||
def test_webscraper_empty_url():
|
||||
webscraper = _build_builtin_tool(WebscraperTool)
|
||||
empty = list(webscraper.invoke(user_id="u", tool_parameters={"url": ""}))[0].message.text
|
||||
assert empty == "Please input url"
|
||||
|
||||
|
||||
def test_webscraper_fetch(monkeypatch):
|
||||
webscraper = _build_builtin_tool(WebscraperTool)
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page")
|
||||
full = list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))[0].message.text
|
||||
assert full == "page"
|
||||
|
||||
|
||||
def test_webscraper_summary(monkeypatch):
|
||||
webscraper = _build_builtin_tool(WebscraperTool)
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page")
|
||||
monkeypatch.setattr(webscraper, "summary", lambda user_id, content: "summary")
|
||||
summarized = list(
|
||||
webscraper.invoke(
|
||||
user_id="u",
|
||||
tool_parameters={"url": "https://example.com", "generate_summary": True},
|
||||
)
|
||||
)[0].message.text
|
||||
assert summarized == "summary"
|
||||
|
||||
|
||||
def test_webscraper_fetch_error(monkeypatch):
|
||||
webscraper = _build_builtin_tool(WebscraperTool)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url",
|
||||
_raise_runtime_error,
|
||||
)
|
||||
with pytest.raises(ToolInvokeError, match="boom"):
|
||||
list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))
|
||||
|
||||
|
||||
def test_asr_invalid_file():
|
||||
asr = _build_builtin_tool(ASRTool)
|
||||
file_obj = SimpleNamespace(type=FileType.DOCUMENT)
|
||||
invalid_file = list(asr.invoke(user_id="u", tool_parameters={"audio_file": file_obj}))[0].message.text
|
||||
assert "not a valid audio file" in invalid_file
|
||||
|
||||
|
||||
def test_asr_valid_file_invocation(monkeypatch):
|
||||
asr = _build_builtin_tool(ASRTool)
|
||||
model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})()
|
||||
model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})()
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes")
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager)
|
||||
audio_file = SimpleNamespace(type=FileType.AUDIO)
|
||||
ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text
|
||||
assert ok == "transcript"
|
||||
|
||||
|
||||
def test_asr_available_models_and_runtime_parameters(monkeypatch):
|
||||
asr = _build_builtin_tool(ASRTool)
|
||||
provider_model = type("PM", (), {"provider": "p", "models": [type("Model", (), {"model": "m"})()]})()
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.asr.ModelProviderService.get_models_by_model_type",
|
||||
lambda *a, **k: [provider_model],
|
||||
)
|
||||
assert asr.get_available_models() == [("p", "m")]
|
||||
assert asr.get_runtime_parameters()[0].name == "model"
|
||||
|
||||
|
||||
def test_tts_invoke_returns_messages(monkeypatch):
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
voices_model_instance = type(
|
||||
"TTSM",
|
||||
(),
|
||||
{
|
||||
"get_tts_voices": lambda self: [{"value": "voice-1"}],
|
||||
"invoke_tts": lambda self, **kwargs: [b"a", b"b"],
|
||||
},
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
|
||||
lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(),
|
||||
)
|
||||
messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
|
||||
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB]
|
||||
|
||||
|
||||
def test_tts_get_available_models_requires_runtime():
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
tts.runtime = None
|
||||
with pytest.raises(ValueError, match="Runtime is required"):
|
||||
tts.get_available_models()
|
||||
|
||||
|
||||
def test_tts_tool_raises_when_runtime_missing():
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
tts.runtime = None
|
||||
with pytest.raises(ValueError, match="Runtime is required"):
|
||||
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"voices",
|
||||
[[{"value": None}], []],
|
||||
)
|
||||
def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices):
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
tts.runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
model_without_voice = type(
|
||||
"TTSModelNoVoice",
|
||||
(),
|
||||
{
|
||||
"get_tts_voices": lambda self: voices,
|
||||
"invoke_tts": lambda self, **kwargs: [b"x"],
|
||||
},
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
|
||||
lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
|
||||
)
|
||||
with pytest.raises(ValueError, match="no voice available"):
|
||||
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
|
||||
|
||||
|
||||
def test_tts_tool_get_available_models_and_runtime_parameters(monkeypatch):
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
|
||||
model_1 = SimpleNamespace(
|
||||
model="model-a",
|
||||
model_properties={ModelPropertyKey.VOICES: [{"mode": "v1", "name": "Voice 1"}]},
|
||||
)
|
||||
model_2 = SimpleNamespace(model="model-b", model_properties={})
|
||||
provider_models = [SimpleNamespace(provider="provider-a", models=[model_1, model_2])]
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelProviderService.get_models_by_model_type",
|
||||
lambda *args, **kwargs: provider_models,
|
||||
)
|
||||
|
||||
available_models = tts.get_available_models()
|
||||
assert available_models == [
|
||||
("provider-a", "model-a", [{"mode": "v1", "name": "Voice 1"}]),
|
||||
("provider-a", "model-b", []),
|
||||
]
|
||||
|
||||
runtime_parameters = tts.get_runtime_parameters()
|
||||
assert runtime_parameters[0].name == "model"
|
||||
assert runtime_parameters[0].required is True
|
||||
assert runtime_parameters[0].options[0].value == "provider-a#model-a"
|
||||
assert runtime_parameters[1].name == "voice#provider-a#model-a"
|
||||
|
||||
|
||||
def test_provider_classes_and_builtin_sort(monkeypatch):
|
||||
# Use object.__new__ to avoid YAML-loading __init__; only pass-through validation is exercised.
|
||||
# Ensure pass-through _validate_credentials methods are executed.
|
||||
AudioToolProvider._validate_credentials(object.__new__(AudioToolProvider), "u", {})
|
||||
CodeToolProvider._validate_credentials(object.__new__(CodeToolProvider), "u", {})
|
||||
WikiPediaProvider._validate_credentials(object.__new__(WikiPediaProvider), "u", {})
|
||||
WebscraperProvider._validate_credentials(object.__new__(WebscraperProvider), "u", {})
|
||||
|
||||
providers = [SimpleNamespace(name="b"), SimpleNamespace(name="a")]
|
||||
monkeypatch.setattr(BuiltinToolProviderSort, "_position", {})
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers._positions.get_tool_position_map",
|
||||
lambda _: {"a": 0, "b": 1},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers._positions.sort_by_position_map",
|
||||
lambda position, values, name_func: sorted(values, key=lambda x: name_func(x)),
|
||||
)
|
||||
sorted_providers = BuiltinToolProviderSort.sort(providers)
|
||||
assert [p.name for p in sorted_providers] == ["a", "b"]
|
||||
285
api/tests/unit_tests/core/tools/test_custom_tool.py
Normal file
285
api/tests/unit_tests/core/tools/test_custom_tool.py
Normal file
@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.tool import ApiTool, ParsedResponse
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
def _build_tool(*, openapi: dict | None = None) -> ApiTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="tool-a",
|
||||
label=I18nObject(en_US="tool-a"),
|
||||
provider="provider-a",
|
||||
),
|
||||
parameters=[],
|
||||
)
|
||||
bundle = ApiToolBundle(
|
||||
server_url="https://api.example.com/items/{id}",
|
||||
method="GET",
|
||||
summary="summary",
|
||||
operation_id="op-id",
|
||||
parameters=[],
|
||||
author="author",
|
||||
openapi=openapi or {"parameters": []},
|
||||
)
|
||||
runtime = ToolRuntime(
|
||||
tenant_id="tenant-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
credentials={"auth_type": "api_key_header", "api_key_value": "k"},
|
||||
)
|
||||
return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="provider-id")
|
||||
|
||||
|
||||
def test_parsed_response_to_string():
|
||||
assert ParsedResponse({"a": 1}, True).to_string() == '{"a": 1}'
|
||||
assert ParsedResponse("ok", False).to_string() == "ok"
|
||||
|
||||
|
||||
def test_api_tool_fork_runtime_and_validate_credentials(monkeypatch):
|
||||
tool = _build_tool()
|
||||
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
|
||||
assert isinstance(forked, ApiTool)
|
||||
assert forked.runtime.tenant_id == "tenant-2"
|
||||
|
||||
tool.api_bundle = None # type: ignore[assignment]
|
||||
with pytest.raises(ValueError, match="api_bundle is required"):
|
||||
tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
|
||||
|
||||
tool = _build_tool()
|
||||
assert tool.validate_credentials(credentials={}, parameters={}, format_only=True) == ""
|
||||
monkeypatch.setattr(tool, "assembling_request", lambda parameters: {"Authorization": "Bearer x"})
|
||||
monkeypatch.setattr(
|
||||
tool,
|
||||
"do_http_request",
|
||||
lambda url, method, headers, parameters: httpx.Response(200, json={"ok": True}),
|
||||
)
|
||||
result = tool.validate_credentials(credentials={}, parameters={"a": 1}, format_only=False)
|
||||
assert result == '{"ok": true}'
|
||||
|
||||
|
||||
def test_assembling_request_auth_header_assembly():
|
||||
tool = _build_tool()
|
||||
|
||||
headers = tool.assembling_request(parameters={})
|
||||
assert headers["Authorization"] == "k"
|
||||
|
||||
tool.runtime.credentials = {
|
||||
"auth_type": "api_key_header",
|
||||
"api_key_header_prefix": "bearer",
|
||||
"api_key_value": "abc",
|
||||
}
|
||||
headers = tool.assembling_request(parameters={})
|
||||
assert headers["Authorization"] == "Bearer abc"
|
||||
|
||||
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_header_prefix": "basic", "api_key_value": "abc"}
|
||||
headers = tool.assembling_request(parameters={})
|
||||
assert headers["Authorization"] == "Basic abc"
|
||||
|
||||
tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_value": "abc"}
|
||||
assert tool.assembling_request(parameters={}) == {}
|
||||
|
||||
|
||||
def test_assembling_request_runtime_auth_errors():
|
||||
tool = _build_tool()
|
||||
|
||||
tool.runtime = None
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="runtime not initialized"):
|
||||
tool.assembling_request(parameters={})
|
||||
|
||||
tool.runtime = ToolRuntime(tenant_id="tenant", credentials={})
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="Missing auth_type"):
|
||||
tool.assembling_request(parameters={})
|
||||
|
||||
tool.runtime.credentials = {"auth_type": "api_key_header"}
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="Missing api_key_value"):
|
||||
tool.assembling_request(parameters={})
|
||||
|
||||
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": 123}
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="must be a string"):
|
||||
tool.assembling_request(parameters={})
|
||||
|
||||
|
||||
def test_assembling_request_parameter_validation_and_defaults():
|
||||
tool = _build_tool()
|
||||
|
||||
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "x"}
|
||||
tool.api_bundle.parameters = [
|
||||
SimpleNamespace(required=True, name="required_param", default=None),
|
||||
]
|
||||
with pytest.raises(ToolParameterValidationError, match="Missing required parameter required_param"):
|
||||
tool.assembling_request(parameters={})
|
||||
|
||||
tool.api_bundle.parameters = [
|
||||
SimpleNamespace(required=True, name="required_param", default="d"),
|
||||
]
|
||||
params = {}
|
||||
tool.assembling_request(parameters=params)
|
||||
assert params["required_param"] == "d"
|
||||
|
||||
|
||||
def test_validate_and_parse_response_branches():
|
||||
tool = _build_tool()
|
||||
|
||||
with pytest.raises(ToolInvokeError, match="status code 500"):
|
||||
tool.validate_and_parse_response(httpx.Response(500, text="boom"))
|
||||
|
||||
empty = tool.validate_and_parse_response(httpx.Response(200, content=b""))
|
||||
assert empty.is_json is False
|
||||
assert "Empty response from the tool" in str(empty.content)
|
||||
|
||||
json_resp = tool.validate_and_parse_response(
|
||||
httpx.Response(200, json={"a": 1}, headers={"content-type": "application/json"})
|
||||
)
|
||||
assert json_resp.is_json is True
|
||||
assert json_resp.content == {"a": 1}
|
||||
|
||||
non_json_type = tool.validate_and_parse_response(
|
||||
httpx.Response(200, text='{"a": 1}', headers={"content-type": "text/plain"})
|
||||
)
|
||||
assert non_json_type.is_json is False
|
||||
assert non_json_type.content == '{"a": 1}'
|
||||
|
||||
plain_resp = tool.validate_and_parse_response(httpx.Response(200, text="plain"))
|
||||
assert plain_resp.is_json is False
|
||||
assert plain_resp.content == "plain"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid response type"):
|
||||
tool.validate_and_parse_response("invalid") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_parameter_value_and_type_conversion_helpers():
|
||||
tool = _build_tool()
|
||||
|
||||
assert tool.get_parameter_value({"name": "x"}, {"x": 1}) == 1
|
||||
assert tool.get_parameter_value({"name": "x", "required": False, "schema": {"default": "d"}}, {}) == "d"
|
||||
with pytest.raises(ToolParameterValidationError, match="Missing required parameter x"):
|
||||
tool.get_parameter_value({"name": "x", "required": True}, {})
|
||||
|
||||
assert tool._convert_body_property_any_of({}, "12", [{"type": "integer"}]) == 12
|
||||
assert tool._convert_body_property_any_of({}, "1.5", [{"type": "number"}]) == 1.5
|
||||
assert tool._convert_body_property_any_of({}, "true", [{"type": "boolean"}]) is True
|
||||
assert tool._convert_body_property_any_of({}, "", [{"type": "null"}]) is None
|
||||
assert tool._convert_body_property_any_of({}, "x", [{"anyOf": [{"type": "string"}]}]) == "x"
|
||||
|
||||
assert tool._convert_body_property_type({"type": "integer"}, "1") == 1
|
||||
assert tool._convert_body_property_type({"type": "number"}, "1.2") == 1.2
|
||||
assert tool._convert_body_property_type({"type": "string"}, 1) == "1"
|
||||
assert tool._convert_body_property_type({"type": "boolean"}, 1) is True
|
||||
assert tool._convert_body_property_type({"type": "null"}, None) is None
|
||||
assert tool._convert_body_property_type({"type": "object"}, '{"a":1}') == {"a": 1}
|
||||
assert tool._convert_body_property_type({"type": "array"}, "[1,2]") == [1, 2]
|
||||
assert tool._convert_body_property_type({"type": "invalid"}, "v") == "v"
|
||||
assert tool._convert_body_property_type({"anyOf": [{"type": "integer"}]}, "2") == 2
|
||||
|
||||
|
||||
def test_do_http_request_builds_arguments_and_handles_invalid_method(monkeypatch):
|
||||
openapi = {
|
||||
"parameters": [
|
||||
{"name": "id", "in": "path", "required": True, "schema": {"type": "string"}},
|
||||
{"name": "q", "in": "query", "required": False, "schema": {"default": ""}},
|
||||
{"name": "X-Extra", "in": "header", "required": False, "schema": {"default": "x"}},
|
||||
{"name": "sid", "in": "cookie", "required": False, "schema": {"default": "cookie1"}},
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"required": ["count"],
|
||||
"properties": {
|
||||
"count": {"type": "integer"},
|
||||
"name": {"type": "string", "default": "n"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
tool = _build_tool(openapi=openapi)
|
||||
tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "v"}
|
||||
headers = {}
|
||||
captured = {}
|
||||
|
||||
def _fake_get(url, **kwargs):
|
||||
captured["url"] = url
|
||||
captured["kwargs"] = kwargs
|
||||
return httpx.Response(200, text="ok")
|
||||
|
||||
monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.get", _fake_get)
|
||||
response = tool.do_http_request(
|
||||
"https://api.example.com/items/{id}",
|
||||
"GET",
|
||||
headers=headers,
|
||||
parameters={"id": "123", "count": "2", "q": "search"},
|
||||
)
|
||||
|
||||
assert isinstance(response, httpx.Response)
|
||||
assert captured["url"].endswith("/items/123")
|
||||
assert captured["kwargs"]["params"]["q"] == "search"
|
||||
assert captured["kwargs"]["params"]["key"] == "v"
|
||||
assert captured["kwargs"]["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
invalid_method_tool = _build_tool(openapi={"parameters": []})
|
||||
with pytest.raises(ValueError, match="Invalid http method"):
|
||||
invalid_method_tool.do_http_request("https://api.example.com", "TRACE", headers={}, parameters={})
|
||||
|
||||
|
||||
def test_do_http_request_handles_file_upload_and_invoke_paths(monkeypatch):
|
||||
openapi = {
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"multipart/form-data": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"file": {"format": "binary"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
tool = _build_tool(openapi=openapi)
|
||||
tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "k"}
|
||||
fake_file = SimpleNamespace(filename="a.txt", mime_type="text/plain")
|
||||
captured = {}
|
||||
|
||||
def _fake_post(url, **kwargs):
|
||||
captured["headers"] = kwargs["headers"]
|
||||
captured["files"] = kwargs["files"]
|
||||
return httpx.Response(200, text="ok")
|
||||
|
||||
monkeypatch.setattr("core.tools.custom_tool.tool.download", lambda _: b"file-bytes")
|
||||
monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.post", _fake_post)
|
||||
response = tool.do_http_request(
|
||||
"https://api.example.com/upload",
|
||||
"POST",
|
||||
headers={},
|
||||
parameters={"file": fake_file},
|
||||
)
|
||||
assert isinstance(response, httpx.Response)
|
||||
assert "Content-Type" not in captured["headers"]
|
||||
assert captured["files"][0][0] == "file"
|
||||
|
||||
# _invoke JSON path
|
||||
monkeypatch.setattr(tool, "assembling_request", lambda parameters: {})
|
||||
monkeypatch.setattr(tool, "do_http_request", lambda *args, **kwargs: httpx.Response(200, text='{"a":1}'))
|
||||
monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse({"a": 1}, True))
|
||||
messages = list(tool.invoke(user_id="u1", tool_parameters={}))
|
||||
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.TEXT]
|
||||
|
||||
# _invoke text path
|
||||
monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse("plain", False))
|
||||
messages = list(tool.invoke(user_id="u1", tool_parameters={}))
|
||||
assert len(messages) == 1
|
||||
assert messages[0].message.text == "plain"
|
||||
75
api/tests/unit_tests/core/tools/test_custom_tool_provider.py
Normal file
75
api/tests/unit_tests/core/tools/test_custom_tool_provider.py
Normal file
@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolProviderType
|
||||
|
||||
|
||||
def _db_provider() -> SimpleNamespace:
|
||||
bundle = ApiToolBundle(
|
||||
server_url="https://api.example.com/items",
|
||||
method="GET",
|
||||
summary="List items",
|
||||
operation_id="list_items",
|
||||
parameters=[],
|
||||
author="author",
|
||||
openapi={"parameters": []},
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id="provider-id",
|
||||
tenant_id="tenant-1",
|
||||
name="provider-a",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
user=SimpleNamespace(name="Alice"),
|
||||
tools=[bundle],
|
||||
)
|
||||
|
||||
|
||||
def test_api_tool_provider_from_db_and_parse_tool_bundle():
|
||||
controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_HEADER)
|
||||
assert controller.provider_type == ToolProviderType.API
|
||||
assert any(c.name == "api_key_value" for c in controller.entity.credentials_schema)
|
||||
|
||||
tool = controller._parse_tool_bundle(_db_provider().tools[0])
|
||||
assert isinstance(tool, ApiTool)
|
||||
assert tool.entity.identity.provider == "provider-id"
|
||||
|
||||
|
||||
def test_api_tool_provider_from_db_query_auth_and_none_auth():
|
||||
query_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_QUERY)
|
||||
assert any(c.name == "api_key_query_param" for c in query_controller.entity.credentials_schema)
|
||||
|
||||
none_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE)
|
||||
assert [c.name for c in none_controller.entity.credentials_schema] == ["auth_type"]
|
||||
|
||||
|
||||
def test_api_tool_provider_load_get_tools_and_get_tool():
|
||||
controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE)
|
||||
loaded = controller.load_bundled_tools(_db_provider().tools)
|
||||
assert len(loaded) == 1
|
||||
|
||||
assert isinstance(controller.get_tool("list_items"), ApiTool)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
controller.get_tool("missing")
|
||||
|
||||
# Return cached tools without querying database.
|
||||
cached = controller.get_tools("tenant-1")
|
||||
assert len(cached) == 1
|
||||
|
||||
# Force DB fetch branch.
|
||||
controller.tools = []
|
||||
provider_with_tools = _db_provider()
|
||||
with patch("core.tools.custom_tool.provider.db") as mock_db:
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [provider_with_tools]
|
||||
mock_db.session.scalars.return_value = scalars_result
|
||||
tools = controller.get_tools("tenant-1")
|
||||
assert len(tools) == 1
|
||||
145
api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py
Normal file
145
api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Unit tests for DatasetRetrieverTool behavior and retrieval wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
def _retrieve_config() -> DatasetRetrieveConfigEntity:
|
||||
return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE)
|
||||
|
||||
|
||||
def test_get_dataset_tools_returns_empty_for_empty_dataset_ids() -> None:
|
||||
# Arrange
|
||||
retrieve_config = _retrieve_config()
|
||||
|
||||
# Act
|
||||
tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id="tenant",
|
||||
dataset_ids=[],
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
hit_callback=Mock(),
|
||||
user_id="u",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_get_dataset_tools_returns_empty_for_missing_retrieve_config() -> None:
|
||||
# Arrange
|
||||
dataset_ids = ["d1"]
|
||||
|
||||
# Act
|
||||
tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id="tenant",
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=None, # type: ignore[arg-type]
|
||||
return_resource=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
hit_callback=Mock(),
|
||||
user_id="u",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_get_dataset_tools_builds_tool_and_restores_strategy() -> None:
|
||||
# Arrange
|
||||
retrieve_config = _retrieve_config()
|
||||
retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}")
|
||||
feature = Mock()
|
||||
feature.to_dataset_retriever_tool.return_value = [retrieval_tool]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature):
|
||||
tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id="tenant",
|
||||
dataset_ids=["d1"],
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=True,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
hit_callback=Mock(),
|
||||
user_id="u",
|
||||
inputs={"x": 1},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(tools) == 1
|
||||
assert tools[0].entity.identity.name == "dataset_tool"
|
||||
assert retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
|
||||
|
||||
|
||||
def _build_dataset_tool() -> tuple[DatasetRetrieverTool, SimpleNamespace]:
|
||||
retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}")
|
||||
feature = Mock()
|
||||
feature.to_dataset_retriever_tool.return_value = [retrieval_tool]
|
||||
with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature):
|
||||
tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id="tenant",
|
||||
dataset_ids=["d1"],
|
||||
retrieve_config=_retrieve_config(),
|
||||
return_resource=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
hit_callback=Mock(),
|
||||
user_id="u",
|
||||
inputs={},
|
||||
)
|
||||
return tools[0], retrieval_tool
|
||||
|
||||
|
||||
def test_runtime_parameters_shape() -> None:
|
||||
# Arrange
|
||||
tool, _ = _build_dataset_tool()
|
||||
|
||||
# Act
|
||||
params = tool.get_runtime_parameters()
|
||||
|
||||
# Assert
|
||||
assert len(params) == 1
|
||||
assert params[0].name == "query"
|
||||
|
||||
|
||||
def test_empty_query_behavior() -> None:
|
||||
# Arrange
|
||||
tool, _ = _build_dataset_tool()
|
||||
|
||||
# Act
|
||||
empty_query = list(tool.invoke(user_id="u", tool_parameters={}))
|
||||
|
||||
# Assert
|
||||
assert len(empty_query) == 1
|
||||
assert empty_query[0].message.text == "please input query"
|
||||
|
||||
|
||||
def test_query_invocation_result() -> None:
|
||||
# Arrange
|
||||
tool, _ = _build_dataset_tool()
|
||||
|
||||
# Act
|
||||
result = list(tool.invoke(user_id="u", tool_parameters={"query": "hello"}))
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].message.text == "result:hello"
|
||||
|
||||
|
||||
def test_validate_credentials() -> None:
|
||||
# Arrange
|
||||
tool, _ = _build_dataset_tool()
|
||||
|
||||
# Act
|
||||
result = tool.validate_credentials(credentials={}, parameters={}, format_only=False)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
150
api/tests/unit_tests/core/tools/test_mcp_tool.py
Normal file
150
api/tests/unit_tests/core/tools/test_mcp_tool.py
Normal file
@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.mcp.types import (
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
|
||||
|
||||
def _build_mcp_tool(*, with_output_schema: bool = True) -> MCPTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="remote-tool",
|
||||
label=I18nObject(en_US="remote-tool"),
|
||||
provider="provider-id",
|
||||
),
|
||||
parameters=[],
|
||||
output_schema={"type": "object"} if with_output_schema else {},
|
||||
)
|
||||
return MCPTool(
|
||||
entity=entity,
|
||||
runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER),
|
||||
tenant_id="tenant-1",
|
||||
icon="icon.svg",
|
||||
server_url="https://mcp.example.com",
|
||||
provider_id="provider-id",
|
||||
headers={"x-auth": "token"},
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_tool_provider_type_and_fork_runtime():
|
||||
tool = _build_mcp_tool()
|
||||
assert tool.tool_provider_type() == ToolProviderType.MCP
|
||||
|
||||
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
|
||||
assert isinstance(forked, MCPTool)
|
||||
assert forked.runtime.tenant_id == "tenant-2"
|
||||
assert forked.provider_id == "provider-id"
|
||||
|
||||
|
||||
def test_mcp_tool_text_and_json_processing_helpers():
|
||||
tool = _build_mcp_tool()
|
||||
|
||||
json_messages = list(tool._process_text_content(TextContent(type="text", text='{"a": 1}')))
|
||||
assert json_messages[0].type == ToolInvokeMessage.MessageType.JSON
|
||||
|
||||
plain_messages = list(tool._process_text_content(TextContent(type="text", text="not-json")))
|
||||
assert plain_messages[0].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert plain_messages[0].message.text == "not-json"
|
||||
|
||||
list_messages = list(tool._process_json_content([{"k": 1}, {"k": 2}]))
|
||||
assert [m.type for m in list_messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.JSON]
|
||||
|
||||
mixed_list_messages = list(tool._process_json_list([{"k": 1}, 2]))
|
||||
assert len(mixed_list_messages) == 1
|
||||
assert mixed_list_messages[0].type == ToolInvokeMessage.MessageType.TEXT
|
||||
|
||||
primitive_messages = list(tool._process_json_content(123))
|
||||
assert primitive_messages[0].message.text == "123"
|
||||
|
||||
|
||||
def test_mcp_tool_usage_extraction_helpers():
|
||||
usage = MCPTool._extract_usage_dict({"usage": {"total_tokens": 9}})
|
||||
assert usage == {"total_tokens": 9}
|
||||
|
||||
usage = MCPTool._extract_usage_dict({"metadata": {"usage": {"prompt_tokens": 3, "completion_tokens": 2}}})
|
||||
assert usage == {"prompt_tokens": 3, "completion_tokens": 2}
|
||||
|
||||
usage = MCPTool._extract_usage_dict({"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3})
|
||||
assert usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
|
||||
|
||||
usage = MCPTool._extract_usage_dict({"nested": [{"deep": {"usage": {"total_tokens": 7}}}]})
|
||||
assert usage == {"total_tokens": 7}
|
||||
|
||||
result_with_usage = CallToolResult(content=[], _meta={"usage": {"prompt_tokens": 1, "completion_tokens": 2}})
|
||||
derived = MCPTool._derive_usage_from_result(result_with_usage)
|
||||
assert derived.prompt_tokens == 1
|
||||
assert derived.completion_tokens == 2
|
||||
|
||||
result_without_usage = CallToolResult(content=[], _meta=None)
|
||||
derived = MCPTool._derive_usage_from_result(result_without_usage)
|
||||
assert derived.total_tokens == 0
|
||||
|
||||
|
||||
def test_mcp_tool_invoke_handles_content_types_and_structured_output():
|
||||
tool = _build_mcp_tool()
|
||||
img_data = base64.b64encode(b"img").decode()
|
||||
blob_data = base64.b64encode(b"blob").decode()
|
||||
result = CallToolResult(
|
||||
content=[
|
||||
TextContent(type="text", text='{"a": 1}'),
|
||||
ImageContent(type="image", data=img_data, mimeType="image/png"),
|
||||
EmbeddedResource(
|
||||
type="resource",
|
||||
resource=TextResourceContents(uri="file:///tmp/a.txt", text="embedded-text"),
|
||||
),
|
||||
EmbeddedResource(
|
||||
type="resource",
|
||||
resource=BlobResourceContents(
|
||||
uri="file:///tmp/b.bin",
|
||||
blob=blob_data,
|
||||
mimeType="application/octet-stream",
|
||||
),
|
||||
),
|
||||
],
|
||||
structuredContent={"x": 1},
|
||||
_meta={"usage": {"prompt_tokens": 2, "completion_tokens": 3}},
|
||||
)
|
||||
|
||||
with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool.invoke(user_id="user-1", tool_parameters={"a": 1}))
|
||||
|
||||
types = [m.type for m in messages]
|
||||
assert ToolInvokeMessage.MessageType.JSON in types
|
||||
assert ToolInvokeMessage.MessageType.BLOB in types
|
||||
assert ToolInvokeMessage.MessageType.TEXT in types
|
||||
assert ToolInvokeMessage.MessageType.VARIABLE in types
|
||||
assert tool.latest_usage.total_tokens == 5
|
||||
|
||||
|
||||
def test_mcp_tool_invoke_raises_for_unsupported_embedded_resource():
|
||||
tool = _build_mcp_tool()
|
||||
# Use model_construct to bypass pydantic validation and force unsupported resource path.
|
||||
bad_resource = EmbeddedResource.model_construct(type="resource", resource=object())
|
||||
result = CallToolResult(content=[bad_resource], _meta=None)
|
||||
|
||||
with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result):
|
||||
with pytest.raises(ToolInvokeError, match="Unsupported embedded resource type"):
|
||||
list(tool.invoke(user_id="user-1", tool_parameters={}))
|
||||
|
||||
|
||||
def test_mcp_tool_handle_none_parameter_filters_empty_values():
|
||||
tool = _build_mcp_tool()
|
||||
cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"})
|
||||
assert cleaned == {"a": 1, "e": "ok"}
|
||||
73
api/tests/unit_tests/core/tools/test_mcp_tool_provider.py
Normal file
73
api/tests/unit_tests/core/tools/test_mcp_tool_provider.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
|
||||
|
||||
def _build_mcp_entity(*, icon: str = "icon.svg") -> MCPProviderEntity:
|
||||
now = datetime.now()
|
||||
return MCPProviderEntity(
|
||||
id="db-id",
|
||||
provider_id="provider-id",
|
||||
name="MCP Provider",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
server_url="https://mcp.example.com",
|
||||
headers={},
|
||||
timeout=30,
|
||||
sse_read_timeout=300,
|
||||
authed=False,
|
||||
credentials={},
|
||||
tools=[
|
||||
{
|
||||
"name": "remote-tool",
|
||||
"description": "remote tool",
|
||||
"inputSchema": {},
|
||||
"outputSchema": {"type": "object"},
|
||||
}
|
||||
],
|
||||
icon=icon,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_tool_provider_controller_from_entity_and_get_tools():
|
||||
entity = _build_mcp_entity()
|
||||
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
|
||||
controller = MCPToolProviderController.from_entity(entity)
|
||||
|
||||
assert controller.provider_type == ToolProviderType.MCP
|
||||
tool = controller.get_tool("remote-tool")
|
||||
assert isinstance(tool, MCPTool)
|
||||
assert tool.tenant_id == "tenant-1"
|
||||
|
||||
tools = controller.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], MCPTool)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
controller.get_tool("missing")
|
||||
|
||||
|
||||
def test_mcp_tool_provider_controller_from_entity_requires_icon():
|
||||
entity = _build_mcp_entity(icon="")
|
||||
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
|
||||
with pytest.raises(ValueError, match="icon is required"):
|
||||
MCPToolProviderController.from_entity(entity)
|
||||
|
||||
|
||||
def test_mcp_tool_provider_controller_from_db_delegates_to_entity():
|
||||
entity = _build_mcp_entity()
|
||||
db_provider = Mock()
|
||||
db_provider.to_entity.return_value = entity
|
||||
with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]):
|
||||
controller = MCPToolProviderController.from_db(db_provider)
|
||||
assert isinstance(controller, MCPToolProviderController)
|
||||
91
api/tests/unit_tests/core/tools/test_plugin_tool.py
Normal file
91
api/tests/unit_tests/core/tools/test_plugin_tool.py
Normal file
@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolParameter
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
def _build_plugin_tool(*, has_runtime_parameters: bool) -> PluginTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="tool-a",
|
||||
label=I18nObject(en_US="tool-a"),
|
||||
provider="provider-a",
|
||||
),
|
||||
parameters=[
|
||||
ToolParameter.get_simple_instance(
|
||||
name="query",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
has_runtime_parameters=has_runtime_parameters,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, credentials={"api_key": "x"})
|
||||
return PluginTool(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id="tenant-1",
|
||||
icon="icon.svg",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
)
|
||||
|
||||
|
||||
def test_plugin_tool_invoke_and_fork_runtime():
|
||||
tool = _build_plugin_tool(has_runtime_parameters=False)
|
||||
manager = Mock()
|
||||
manager.invoke.return_value = iter([tool.create_text_message("ok")])
|
||||
|
||||
with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager):
|
||||
with patch(
|
||||
"core.tools.plugin_tool.tool.convert_parameters_to_plugin_format",
|
||||
return_value={"converted": 1},
|
||||
):
|
||||
messages = list(tool.invoke(user_id="user-1", tool_parameters={"raw": 1}))
|
||||
|
||||
assert [m.message.text for m in messages] == ["ok"]
|
||||
manager.invoke.assert_called_once()
|
||||
assert manager.invoke.call_args.kwargs["tool_parameters"] == {"converted": 1}
|
||||
|
||||
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2"))
|
||||
assert isinstance(forked, PluginTool)
|
||||
assert forked.runtime.tenant_id == "tenant-2"
|
||||
assert forked.plugin_unique_identifier == "plugin-uid"
|
||||
|
||||
|
||||
def test_plugin_tool_get_runtime_parameters_branches():
|
||||
tool = _build_plugin_tool(has_runtime_parameters=False)
|
||||
assert tool.get_runtime_parameters() == tool.entity.parameters
|
||||
|
||||
tool = _build_plugin_tool(has_runtime_parameters=True)
|
||||
cached = [
|
||||
ToolParameter.get_simple_instance(
|
||||
name="k",
|
||||
llm_description="k",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
tool.runtime_parameters = cached
|
||||
assert tool.get_runtime_parameters() == cached
|
||||
|
||||
tool.runtime_parameters = None
|
||||
manager = Mock()
|
||||
returned = [
|
||||
ToolParameter.get_simple_instance(
|
||||
name="dyn",
|
||||
llm_description="dyn",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
manager.get_runtime_parameters.return_value = returned
|
||||
with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager):
|
||||
assert tool.get_runtime_parameters(conversation_id="c1", app_id="a1", message_id="m1") == returned
|
||||
assert tool.runtime_parameters == returned
|
||||
89
api/tests/unit_tests/core/tools/test_plugin_tool_provider.py
Normal file
89
api/tests/unit_tests/core/tools/test_plugin_tool_provider.py
Normal file
@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
def _build_controller() -> PluginToolProviderController:
|
||||
tool_entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name="tool-a",
|
||||
label=I18nObject(en_US="tool-a"),
|
||||
provider="provider-a",
|
||||
),
|
||||
parameters=[],
|
||||
)
|
||||
entity = ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author="author",
|
||||
name="provider-a",
|
||||
description=I18nObject(en_US="desc"),
|
||||
icon="icon.svg",
|
||||
label=I18nObject(en_US="Provider"),
|
||||
),
|
||||
credentials_schema=[],
|
||||
plugin_id="plugin-id",
|
||||
tools=[tool_entity],
|
||||
)
|
||||
return PluginToolProviderController(
|
||||
entity=entity,
|
||||
plugin_id="plugin-id",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
|
||||
def test_plugin_tool_provider_controller_basic_behaviors():
|
||||
controller = _build_controller()
|
||||
assert controller.provider_type == ToolProviderType.PLUGIN
|
||||
|
||||
tool = controller.get_tool("tool-a")
|
||||
assert isinstance(tool, PluginTool)
|
||||
assert tool.runtime.tenant_id == "tenant-1"
|
||||
|
||||
tools = controller.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], PluginTool)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
controller.get_tool("missing")
|
||||
|
||||
|
||||
def test_validate_credentials_success():
|
||||
controller = _build_controller()
|
||||
manager = Mock()
|
||||
manager.validate_provider_credentials.return_value = True
|
||||
|
||||
with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager):
|
||||
controller._validate_credentials(user_id="u1", credentials={"api_key": "x"})
|
||||
|
||||
manager.validate_provider_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
user_id="u1",
|
||||
provider="provider-a",
|
||||
credentials={"api_key": "x"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_failure():
|
||||
controller = _build_controller()
|
||||
manager = Mock()
|
||||
manager.validate_provider_credentials.return_value = False
|
||||
|
||||
with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager):
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"):
|
||||
controller._validate_credentials(user_id="u1", credentials={"api_key": "x"})
|
||||
119
api/tests/unit_tests/core/tools/test_signature.py
Normal file
119
api/tests/unit_tests/core/tools/test_signature.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""Unit tests for core.tools.signature covering signing and verification invariants."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature
|
||||
|
||||
|
||||
def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x01" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120)
|
||||
|
||||
url = sign_tool_file("tool-file-id", ".png", for_external=False)
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
timestamp = query["timestamp"][0]
|
||||
nonce = query["nonce"][0]
|
||||
sign = query["sign"][0]
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "internal.example.com"
|
||||
assert parsed.path == "/files/tools/tool-file-id.png"
|
||||
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is True
|
||||
|
||||
|
||||
def test_sign_tool_file_for_external_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x04" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120)
|
||||
|
||||
url = sign_tool_file("tool-file-id", ".png", for_external=True)
|
||||
parsed = urlparse(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "files.example.com"
|
||||
assert parsed.path == "/files/tools/tool-file-id.png"
|
||||
|
||||
|
||||
def test_verify_tool_file_signature_rejects_invalid_sign(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10)
|
||||
|
||||
url = sign_tool_file("tool-file-id", ".txt")
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
timestamp = query["timestamp"][0]
|
||||
nonce = query["nonce"][0]
|
||||
sign = query["sign"][0]
|
||||
|
||||
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, "bad-signature") is False
|
||||
|
||||
|
||||
def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10)
|
||||
|
||||
url = sign_tool_file("tool-file-id", ".txt")
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
timestamp = query["timestamp"][0]
|
||||
nonce = query["nonce"][0]
|
||||
sign = query["sign"][0]
|
||||
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: int(timestamp) + 99)
|
||||
assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False
|
||||
|
||||
|
||||
def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
|
||||
url = sign_upload_file("upload-id", ".png")
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
assert parsed.netloc == "internal.example.com"
|
||||
assert parsed.path == "/files/upload-id/image-preview"
|
||||
assert query["timestamp"][0]
|
||||
assert query["nonce"][0]
|
||||
assert query["sign"][0]
|
||||
|
||||
|
||||
def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
|
||||
|
||||
url = sign_upload_file("upload-id", ".png")
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
assert parsed.netloc == "files.example.com"
|
||||
assert parsed.path == "/files/upload-id/image-preview"
|
||||
assert query["timestamp"][0]
|
||||
assert query["nonce"][0]
|
||||
assert query["sign"][0]
|
||||
280
api/tests/unit_tests/core/tools/test_tool_engine.py
Normal file
280
api/tests/unit_tests/core/tools/test_tool_engine.py
Normal file
@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolInvokeMessageBinary,
|
||||
ToolInvokeMeta,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolParameterValidationError,
|
||||
)
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
|
||||
|
||||
class _DummyTool(Tool):
|
||||
result: Any
|
||||
raise_error: Exception | None
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
self.result = [self.create_text_message("ok")]
|
||||
self.raise_error = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
if self.raise_error:
|
||||
raise self.raise_error
|
||||
if isinstance(self.result, list | Generator):
|
||||
yield from self.result
|
||||
else:
|
||||
yield self.result
|
||||
|
||||
|
||||
def _build_tool(with_llm_parameter: bool = False) -> _DummyTool:
|
||||
parameters = []
|
||||
if with_llm_parameter:
|
||||
parameters = [
|
||||
ToolParameter.get_simple_instance(
|
||||
name="query",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
|
||||
parameters=parameters,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, runtime_parameters={"rt": 1})
|
||||
return _DummyTool(entity=entity, runtime=runtime)
|
||||
|
||||
|
||||
def test_convert_tool_response_to_str_and_extract_binary_messages():
|
||||
tool = _build_tool()
|
||||
messages = [
|
||||
tool.create_text_message("hello"),
|
||||
tool.create_link_message("https://example.com"),
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=ToolInvokeMessage.TextMessage(text="https://example.com/a.png"),
|
||||
meta={"mime_type": "image/png"},
|
||||
),
|
||||
tool.create_json_message({"a": 1}),
|
||||
tool.create_json_message({"a": 1}, suppress_output=True),
|
||||
]
|
||||
text = ToolEngine._convert_tool_response_to_str(messages)
|
||||
assert "hello" in text
|
||||
assert "result link: https://example.com." in text
|
||||
assert '"a": 1' in text
|
||||
|
||||
blob_message = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.TextMessage(text="https://example.com/blob.bin"),
|
||||
meta={"mime_type": "application/octet-stream"},
|
||||
)
|
||||
link_message = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text="https://example.com/file.pdf"),
|
||||
meta={"mime_type": "application/pdf"},
|
||||
)
|
||||
binaries = list(ToolEngine._extract_tool_response_binary_and_text([messages[2], blob_message, link_message]))
|
||||
assert [b.mimetype for b in binaries] == ["image/png", "application/octet-stream", "application/pdf"]
|
||||
|
||||
with pytest.raises(ValueError, match="missing meta data"):
|
||||
list(
|
||||
ToolEngine._extract_tool_response_binary_and_text(
|
||||
[
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=ToolInvokeMessage.TextMessage(text="x"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_create_message_files_and_invoke_generator():
|
||||
binaries = [
|
||||
ToolInvokeMessageBinary(mimetype="image/png", url="https://example.com/abc.png"),
|
||||
ToolInvokeMessageBinary(mimetype="audio/wav", url="https://example.com/def.wav"),
|
||||
]
|
||||
created = []
|
||||
|
||||
def _message_file_factory(**kwargs):
|
||||
obj = SimpleNamespace(id=f"mf-{len(created) + 1}", **kwargs)
|
||||
created.append(obj)
|
||||
return obj
|
||||
|
||||
with patch("core.tools.tool_engine.MessageFile", side_effect=_message_file_factory):
|
||||
with patch("core.tools.tool_engine.db") as mock_db:
|
||||
ids = ToolEngine._create_message_files(
|
||||
tool_messages=binaries,
|
||||
agent_message=SimpleNamespace(id="msg-1"),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert ids == ["mf-1", "mf-2"]
|
||||
assert mock_db.session.add.call_count == 2
|
||||
mock_db.session.close.assert_called_once()
|
||||
|
||||
tool = _build_tool()
|
||||
invoked = list(ToolEngine._invoke(tool, {"a": 1}, user_id="u"))
|
||||
assert invoked[0].type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(invoked[-1], ToolInvokeMeta)
|
||||
assert invoked[-1].error is None
|
||||
|
||||
|
||||
def test_generic_invoke_success_and_error_paths():
|
||||
tool = _build_tool()
|
||||
callback = Mock()
|
||||
callback.on_tool_execution.side_effect = lambda **kwargs: kwargs["tool_outputs"]
|
||||
response = list(
|
||||
ToolEngine.generic_invoke(
|
||||
tool=tool,
|
||||
tool_parameters={"x": 1},
|
||||
user_id="u1",
|
||||
workflow_tool_callback=callback,
|
||||
workflow_call_depth=0,
|
||||
conversation_id="c1",
|
||||
app_id="a1",
|
||||
message_id="m1",
|
||||
)
|
||||
)
|
||||
assert response[0].message.text == "ok"
|
||||
callback.on_tool_start.assert_called_once()
|
||||
callback.on_tool_execution.assert_called_once()
|
||||
|
||||
tool.raise_error = RuntimeError("boom")
|
||||
error_callback = Mock()
|
||||
error_callback.on_tool_execution.side_effect = lambda **kwargs: list(kwargs["tool_outputs"])
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
list(
|
||||
ToolEngine.generic_invoke(
|
||||
tool=tool,
|
||||
tool_parameters={"x": 1},
|
||||
user_id="u1",
|
||||
workflow_tool_callback=error_callback,
|
||||
workflow_call_depth=0,
|
||||
)
|
||||
)
|
||||
error_callback.on_tool_error.assert_called_once()
|
||||
|
||||
|
||||
def test_agent_invoke_success():
|
||||
tool = _build_tool(with_llm_parameter=True)
|
||||
callback = Mock()
|
||||
message = SimpleNamespace(id="m1", conversation_id="c1")
|
||||
meta = ToolInvokeMeta.empty()
|
||||
|
||||
with patch.object(ToolEngine, "_invoke", return_value=iter([tool.create_text_message("ok"), meta])):
|
||||
with patch(
|
||||
"core.tools.tool_engine.ToolFileMessageTransformer.transform_tool_invoke_messages",
|
||||
side_effect=lambda messages, **kwargs: messages,
|
||||
):
|
||||
with patch.object(ToolEngine, "_extract_tool_response_binary_and_text", return_value=iter([])):
|
||||
with patch.object(ToolEngine, "_create_message_files", return_value=[]):
|
||||
result_text, message_files, result_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters="hello",
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
message=message,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
agent_tool_callback=callback,
|
||||
)
|
||||
|
||||
assert result_text == "ok"
|
||||
assert message_files == []
|
||||
assert result_meta.error is None
|
||||
callback.on_tool_start.assert_called_once()
|
||||
callback.on_tool_end.assert_called_once()
|
||||
|
||||
|
||||
def test_agent_invoke_param_validation_error():
|
||||
tool = _build_tool(with_llm_parameter=True)
|
||||
callback = Mock()
|
||||
message = SimpleNamespace(id="m1", conversation_id="c1")
|
||||
|
||||
with patch.object(ToolEngine, "_invoke", side_effect=ToolParameterValidationError("bad-param")):
|
||||
error_text, files, error_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters={"a": 1},
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
message=message,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
agent_tool_callback=callback,
|
||||
)
|
||||
|
||||
assert "tool parameters validation error" in error_text
|
||||
assert files == []
|
||||
assert error_meta.error
|
||||
|
||||
|
||||
def test_agent_invoke_engine_meta_error():
|
||||
tool = _build_tool(with_llm_parameter=True)
|
||||
callback = Mock()
|
||||
message = SimpleNamespace(id="m1", conversation_id="c1")
|
||||
engine_error = ToolEngineInvokeError(ToolInvokeMeta.error_instance("meta failure"))
|
||||
|
||||
with patch.object(ToolEngine, "_invoke", side_effect=engine_error):
|
||||
error_text, files, error_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters={"a": 1},
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
message=message,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
agent_tool_callback=callback,
|
||||
)
|
||||
|
||||
assert "meta failure" in error_text
|
||||
assert files == []
|
||||
assert error_meta.error == "meta failure"
|
||||
|
||||
|
||||
def test_agent_invoke_tool_invoke_error():
|
||||
tool = _build_tool(with_llm_parameter=True)
|
||||
callback = Mock()
|
||||
message = SimpleNamespace(id="m1", conversation_id="c1")
|
||||
|
||||
with patch.object(ToolEngine, "_invoke", side_effect=ToolInvokeError("invoke boom")):
|
||||
error_text, files, _ = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters={"a": 1},
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
message=message,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
agent_tool_callback=callback,
|
||||
)
|
||||
|
||||
assert "tool invoke error" in error_text
|
||||
assert files == []
|
||||
249
api/tests/unit_tests/core/tools/test_tool_file_manager.py
Normal file
249
api/tests/unit_tests/core/tools/test_tool_file_manager.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Unit tests for `ToolFileManager` behavior.
|
||||
|
||||
Covers signing/verification, file persistence flows, and retrieval APIs with
|
||||
mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to
|
||||
avoid real IO.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
|
||||
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100)
|
||||
|
||||
url = ToolFileManager.sign_file("tf-1", ".png")
|
||||
return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&"))
|
||||
|
||||
|
||||
def _patch_session_factory(session: Mock):
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm)
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
url = ToolFileManager.sign_file("tf-1", ".png")
|
||||
assert "/files/tools/tf-1.png" in url
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100)
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False
|
||||
|
||||
|
||||
def test_create_file_by_raw_stores_file_and_persists_record() -> None:
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.refresh.side_effect = lambda model: setattr(model, "id", "tf-1")
|
||||
|
||||
def tool_file_factory(**kwargs):
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
with (
|
||||
patch("core.tools.tool_file_manager.storage") as storage,
|
||||
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
|
||||
patch("core.tools.tool_file_manager.guess_extension", return_value=".txt"),
|
||||
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="abc")),
|
||||
_patch_session_factory(session),
|
||||
):
|
||||
file_model = manager.create_file_by_raw(
|
||||
user_id="u1",
|
||||
tenant_id="t1",
|
||||
conversation_id="c1",
|
||||
file_binary=b"hello",
|
||||
mimetype="text/plain",
|
||||
filename="readme",
|
||||
)
|
||||
|
||||
assert file_model.name.endswith(".txt")
|
||||
storage.save.assert_called_once()
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(file_model)
|
||||
|
||||
|
||||
def test_create_file_by_url_downloads_and_persists_record() -> None:
|
||||
manager = ToolFileManager()
|
||||
response = Mock()
|
||||
response.content = b"binary"
|
||||
response.headers = {"Content-Type": "application/octet-stream"}
|
||||
response.raise_for_status.return_value = None
|
||||
session = Mock()
|
||||
|
||||
def tool_file_factory(**kwargs):
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
session.refresh.side_effect = lambda model: setattr(model, "id", "tf-2")
|
||||
with (
|
||||
patch("core.tools.tool_file_manager.storage") as storage,
|
||||
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
|
||||
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")),
|
||||
_patch_session_factory(session),
|
||||
patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response),
|
||||
):
|
||||
file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
|
||||
|
||||
assert file_model.file_key.startswith("tools/t1/")
|
||||
storage.save.assert_called_once()
|
||||
session.add.assert_called_once_with(file_model)
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(file_model)
|
||||
|
||||
|
||||
def test_create_file_by_url_raises_on_timeout() -> None:
|
||||
manager = ToolFileManager()
|
||||
|
||||
with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")):
|
||||
with pytest.raises(ValueError, match="timeout when downloading file"):
|
||||
manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
|
||||
|
||||
|
||||
def test_get_file_binary_returns_none_when_not_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
result = manager.get_file_binary("missing")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_file_binary_returns_bytes_when_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
storage.load_once.return_value = b"hello"
|
||||
with _patch_session_factory(session):
|
||||
result = manager.get_file_binary("id1")
|
||||
|
||||
# Assert
|
||||
assert result == (b"hello", "text/plain")
|
||||
|
||||
|
||||
def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
result = manager.get_file_binary_by_message_file_id("mf-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url=None)
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
result = manager.get_file_binary_by_message_file_id("mf-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = tool_file
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
storage.load_once.return_value = b"img"
|
||||
with _patch_session_factory(session):
|
||||
result = manager.get_file_binary_by_message_file_id("mf-1")
|
||||
|
||||
# Assert
|
||||
assert result == (b"img", "image/png")
|
||||
|
||||
|
||||
def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
stream, tool_file = manager.get_file_generator_by_tool_file_id("tool123")
|
||||
|
||||
# Assert
|
||||
assert stream is None
|
||||
assert tool_file is None
|
||||
|
||||
|
||||
def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
stream = iter([b"a", b"b"])
|
||||
storage.load_stream.return_value = stream
|
||||
with (
|
||||
_patch_session_factory(session),
|
||||
patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"),
|
||||
):
|
||||
result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123")
|
||||
assert list(result_stream) == [b"a", b"b"]
|
||||
assert result_file == "validated-file"
|
||||
92
api/tests/unit_tests/core/tools/test_tool_label_manager.py
Normal file
92
api/tests/unit_tests/core/tools/test_tool_label_manager.py
Normal file
@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
|
||||
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
return None
|
||||
|
||||
|
||||
def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController:
|
||||
controller = object.__new__(ApiToolProviderController)
|
||||
controller.provider_id = provider_id
|
||||
return controller
|
||||
|
||||
|
||||
def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderController:
|
||||
controller = object.__new__(WorkflowToolProviderController)
|
||||
controller.provider_id = provider_id
|
||||
return controller
|
||||
|
||||
|
||||
def test_tool_label_manager_filter_tool_labels():
|
||||
filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
|
||||
assert set(filtered) == {"search", "news"}
|
||||
assert len(filtered) == 2
|
||||
|
||||
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
delete_query = mock_db.session.query.return_value.where.return_value
|
||||
delete_query.delete.return_value = None
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
delete_query.delete.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_tool_label_manager_update_tool_labels_unsupported():
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
|
||||
with patch.object(
|
||||
_ConcreteBuiltinToolProviderController,
|
||||
"tool_labels",
|
||||
new_callable=PropertyMock,
|
||||
return_value=["search", "news"],
|
||||
):
|
||||
builtin = object.__new__(_ConcreteBuiltinToolProviderController)
|
||||
assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
|
||||
|
||||
api = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
|
||||
labels = ToolLabelManager.get_tool_labels(api)
|
||||
assert labels == ["search", "news"]
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_tool_label_manager_get_tools_labels_batch():
|
||||
assert ToolLabelManager.get_tools_labels([]) == {}
|
||||
|
||||
api = _api_controller("api-1")
|
||||
wf = _workflow_controller("wf-1")
|
||||
records = [
|
||||
SimpleNamespace(tool_id="api-1", label_name="search"),
|
||||
SimpleNamespace(tool_id="api-1", label_name="news"),
|
||||
SimpleNamespace(tool_id="wf-1", label_name="utilities"),
|
||||
]
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
mock_db.session.scalars.return_value.all.return_value = records
|
||||
labels = ToolLabelManager.get_tools_labels([api, wf])
|
||||
assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported tool type"):
|
||||
ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item]
|
||||
899
api/tests/unit_tests/core/tools/test_tool_manager.py
Normal file
899
api/tests/unit_tests/core/tools/test_tool_manager.py
Normal file
@ -0,0 +1,899 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Unit tests for ToolManager behavior with mocked providers and collaborators."""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
class _SimpleContextVar:
|
||||
def __init__(self):
|
||||
self._is_set = False
|
||||
self._value: Any = None
|
||||
|
||||
def get(self):
|
||||
if not self._is_set:
|
||||
raise LookupError
|
||||
return self._value
|
||||
|
||||
def set(self, value: Any):
|
||||
self._value = value
|
||||
self._is_set = True
|
||||
|
||||
|
||||
def _cm(session: Any):
|
||||
context = Mock()
|
||||
context.__enter__ = Mock(return_value=session)
|
||||
context.__exit__ = Mock(return_value=False)
|
||||
return context
|
||||
|
||||
|
||||
def _setup_list_providers_from_api_mocks(
|
||||
monkeypatch,
|
||||
*,
|
||||
session: Mock,
|
||||
hardcoded_controller: SimpleNamespace,
|
||||
plugin_controller: PluginToolProviderController,
|
||||
api_controller: SimpleNamespace,
|
||||
workflow_controller: SimpleNamespace,
|
||||
):
|
||||
mock_db = Mock()
|
||||
mock_db.engine = object()
|
||||
monkeypatch.setattr("core.tools.tool_manager.db", mock_db)
|
||||
monkeypatch.setattr("core.tools.tool_manager.Session", lambda *args, **kwargs: _cm(session))
|
||||
monkeypatch.setattr(
|
||||
ToolManager,
|
||||
"list_builtin_providers",
|
||||
Mock(return_value=[hardcoded_controller, plugin_controller]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ToolManager,
|
||||
"list_default_builtin_providers",
|
||||
Mock(return_value=[SimpleNamespace(provider="hardcoded")]),
|
||||
)
|
||||
monkeypatch.setattr("core.tools.tool_manager.is_filtered", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolTransformService.builtin_provider_to_user_provider",
|
||||
lambda **kwargs: SimpleNamespace(name=kwargs["provider_controller"].entity.identity.name),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolTransformService.api_provider_to_controller",
|
||||
Mock(side_effect=[api_controller, RuntimeError("invalid")]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolTransformService.api_provider_to_user_provider",
|
||||
Mock(return_value=SimpleNamespace(name="api-provider")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller",
|
||||
Mock(side_effect=[workflow_controller, RuntimeError("deleted app")]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_user_provider",
|
||||
Mock(return_value=SimpleNamespace(name="workflow-provider")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolLabelManager.get_tools_labels",
|
||||
Mock(side_effect=[{"api-1": ["search"]}, {"wf-1": ["utility"]}]),
|
||||
)
|
||||
mock_mcp_service = Mock()
|
||||
mock_mcp_service.list_providers.return_value = [SimpleNamespace(name="mcp-provider")]
|
||||
monkeypatch.setattr("core.tools.tool_manager.MCPToolManageService", Mock(return_value=mock_mcp_service))
|
||||
monkeypatch.setattr("core.tools.tool_manager.BuiltinToolProviderSort.sort", lambda providers: providers)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tool_manager_state():
|
||||
old_hardcoded = ToolManager._hardcoded_providers.copy()
|
||||
old_loaded = ToolManager._builtin_providers_loaded
|
||||
old_labels = ToolManager._builtin_tools_labels.copy()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
ToolManager._hardcoded_providers = old_hardcoded
|
||||
ToolManager._builtin_providers_loaded = old_loaded
|
||||
ToolManager._builtin_tools_labels = old_labels
|
||||
|
||||
|
||||
def test_get_hardcoded_provider_loads_cache_when_empty():
|
||||
provider = Mock()
|
||||
ToolManager._hardcoded_providers = {}
|
||||
|
||||
def _load():
|
||||
ToolManager._hardcoded_providers["weather"] = provider
|
||||
|
||||
with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load) as mock_load:
|
||||
assert ToolManager.get_hardcoded_provider("weather") is provider
|
||||
|
||||
mock_load.assert_called_once()
|
||||
|
||||
|
||||
def test_get_builtin_provider_returns_plugin_for_missing_hardcoded():
|
||||
hardcoded = Mock()
|
||||
plugin_provider = Mock()
|
||||
ToolManager._hardcoded_providers = {"time": hardcoded}
|
||||
|
||||
with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider):
|
||||
assert ToolManager.get_builtin_provider("time", "tenant-1") is hardcoded
|
||||
assert ToolManager.get_builtin_provider("plugin/time", "tenant-1") is plugin_provider
|
||||
|
||||
|
||||
def test_get_plugin_provider_uses_context_cache():
|
||||
provider_context = _SimpleContextVar()
|
||||
lock_context = _SimpleContextVar()
|
||||
lock_context.set(threading.Lock())
|
||||
provider_entity = SimpleNamespace(declaration=Mock(), plugin_id="pid", plugin_unique_identifier="uid")
|
||||
|
||||
with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context):
|
||||
with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context):
|
||||
with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls:
|
||||
mock_manager_cls.return_value.fetch_tool_provider.return_value = provider_entity
|
||||
controller = SimpleNamespace(name="controller")
|
||||
with patch("core.tools.tool_manager.PluginToolProviderController", return_value=controller):
|
||||
built = ToolManager.get_plugin_provider("provider-a", "tenant-1")
|
||||
cached = ToolManager.get_plugin_provider("provider-a", "tenant-1")
|
||||
|
||||
assert built is controller
|
||||
assert cached is controller
|
||||
mock_manager_cls.return_value.fetch_tool_provider.assert_called_once()
|
||||
|
||||
|
||||
def test_get_plugin_provider_raises_when_provider_missing():
|
||||
provider_context = _SimpleContextVar()
|
||||
lock_context = _SimpleContextVar()
|
||||
lock_context.set(threading.Lock())
|
||||
|
||||
with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context):
|
||||
with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context):
|
||||
with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls:
|
||||
mock_manager_cls.return_value.fetch_tool_provider.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="plugin provider provider-a not found"):
|
||||
ToolManager.get_plugin_provider("provider-a", "tenant-1")
|
||||
|
||||
|
||||
def test_get_tool_runtime_builtin_without_credentials():
|
||||
tool = Mock()
|
||||
tool.fork_tool_runtime.return_value = "runtime-tool"
|
||||
controller = SimpleNamespace(get_tool=Mock(return_value=tool), need_credentials=False)
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
result = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="time",
|
||||
tool_name="current_time",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
assert result == "runtime-tool"
|
||||
runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"]
|
||||
assert runtime.tenant_id == "tenant-1"
|
||||
assert runtime.credentials == {}
|
||||
|
||||
|
||||
def test_get_tool_runtime_builtin_missing_tool_raises():
|
||||
controller = SimpleNamespace(get_tool=Mock(return_value=None), need_credentials=False)
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="builtin tool missing not found"):
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="time",
|
||||
tool_name="missing",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
|
||||
tool = Mock()
|
||||
tool.fork_tool_runtime.return_value = "runtime-tool"
|
||||
controller = SimpleNamespace(
|
||||
get_tool=Mock(return_value=tool),
|
||||
need_credentials=True,
|
||||
get_credentials_schema_by_type=Mock(return_value=[]),
|
||||
)
|
||||
builtin_provider = SimpleNamespace(
|
||||
id="cred-1",
|
||||
credential_type=CredentialType.API_KEY.value,
|
||||
credentials={"encrypted": "value"},
|
||||
expires_at=-1,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
builtin_provider
|
||||
)
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key": "secret"}
|
||||
cache = Mock()
|
||||
with patch("core.tools.tool_manager.create_provider_encrypter", return_value=(encrypter, cache)):
|
||||
result = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="time",
|
||||
tool_name="weekday",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
assert result == "runtime-tool"
|
||||
runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"]
|
||||
assert runtime.credentials == {"api_key": "secret"}
|
||||
assert runtime.credential_type == CredentialType.API_KEY
|
||||
|
||||
|
||||
@patch("core.tools.tool_manager.create_provider_encrypter")
|
||||
@patch("core.plugin.impl.oauth.OAuthHandler")
|
||||
@patch(
|
||||
"services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client",
|
||||
return_value={"client_id": "id"},
|
||||
)
|
||||
@patch("core.tools.tool_manager.db")
|
||||
@patch("core.tools.tool_manager.time.time", return_value=1000)
|
||||
@patch("core.helper.credential_utils.check_credential_policy_compliance")
|
||||
def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
|
||||
mock_check,
|
||||
mock_time,
|
||||
mock_db,
|
||||
mock_get_oauth_client,
|
||||
mock_oauth_handler_cls,
|
||||
mock_create_provider_encrypter,
|
||||
):
|
||||
tool = Mock()
|
||||
tool.fork_tool_runtime.return_value = "runtime-tool"
|
||||
controller = SimpleNamespace(
|
||||
get_tool=Mock(return_value=tool),
|
||||
need_credentials=True,
|
||||
get_credentials_schema_by_type=Mock(return_value=[]),
|
||||
)
|
||||
builtin_provider = SimpleNamespace(
|
||||
id="cred-1",
|
||||
credential_type=CredentialType.OAUTH2.value,
|
||||
credentials={"encrypted": "value"},
|
||||
encrypted_credentials=None,
|
||||
expires_at=1,
|
||||
user_id="user-1",
|
||||
)
|
||||
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"token": "old"}
|
||||
encrypter.encrypt.return_value = {"token": "encrypted"}
|
||||
cache = Mock()
|
||||
mock_create_provider_encrypter.return_value = (encrypter, cache)
|
||||
mock_oauth_handler_cls.return_value.refresh_credentials.return_value = refreshed
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
result = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="time",
|
||||
tool_name="weekday",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
assert result == "runtime-tool"
|
||||
assert builtin_provider.expires_at == refreshed.expires_at
|
||||
assert builtin_provider.encrypted_credentials == json.dumps({"token": "encrypted"})
|
||||
mock_db.session.commit.assert_called_once()
|
||||
cache.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tool_runtime_builtin_plugin_provider_deleted_raises():
|
||||
plugin_controller = object.__new__(PluginToolProviderController)
|
||||
plugin_controller.entity = SimpleNamespace(credentials_schema=[{"name": "k"}], oauth_schema=None)
|
||||
plugin_controller.get_tool = Mock(return_value=Mock())
|
||||
plugin_controller.get_credentials_schema_by_type = Mock(return_value=[])
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_controller):
|
||||
with patch("core.tools.tool_manager.is_valid_uuid", return_value=True):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="provider has been deleted"):
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="time",
|
||||
tool_name="weekday",
|
||||
tenant_id="tenant-1",
|
||||
credential_id="uuid-id",
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_api_path():
|
||||
api_tool = Mock()
|
||||
api_tool.fork_tool_runtime.return_value = "api-runtime"
|
||||
api_provider = Mock()
|
||||
api_provider.get_tool.return_value = api_tool
|
||||
|
||||
with patch.object(ToolManager, "get_api_provider_controller", return_value=(api_provider, {"c": "enc"})):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"c": "dec"}
|
||||
with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())):
|
||||
assert (
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
== "api-runtime"
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_workflow_path():
|
||||
workflow_provider = SimpleNamespace(tenant_id="tenant-1")
|
||||
workflow_tool = Mock()
|
||||
workflow_tool.fork_tool_runtime.return_value = "wf-runtime"
|
||||
workflow_controller = Mock()
|
||||
workflow_controller.get_tools.return_value = [workflow_tool]
|
||||
session = Mock()
|
||||
session.begin.return_value = _cm(None)
|
||||
session.scalar.return_value = workflow_provider
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
|
||||
with patch(
|
||||
"core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller",
|
||||
return_value=workflow_controller,
|
||||
):
|
||||
assert (
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_id="wf-1",
|
||||
tool_name="wf",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
== "wf-runtime"
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_plugin_path():
|
||||
with patch.object(
|
||||
ToolManager,
|
||||
"get_plugin_provider",
|
||||
return_value=SimpleNamespace(get_tool=lambda _: "plugin-tool"),
|
||||
):
|
||||
assert (
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.PLUGIN,
|
||||
provider_id="plugin-1",
|
||||
tool_name="p",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
== "plugin-tool"
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_mcp_path():
|
||||
with patch.object(
|
||||
ToolManager,
|
||||
"get_mcp_provider_controller",
|
||||
return_value=SimpleNamespace(get_tool=lambda _: "mcp-tool"),
|
||||
):
|
||||
assert (
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.MCP,
|
||||
provider_id="mcp-1",
|
||||
tool_name="m",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
== "mcp-tool"
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_app_not_implemented():
|
||||
with pytest.raises(NotImplementedError, match="app provider not implemented"):
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.APP,
|
||||
provider_id="app",
|
||||
tool_name="x",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_agent_runtime_apply_runtime_parameters():
|
||||
parameter = ToolParameter.get_simple_instance(
|
||||
name="query",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
parameter.form = ToolParameter.ToolParameterForm.FORM
|
||||
|
||||
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
|
||||
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager):
|
||||
agent_tool = SimpleNamespace(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tool_parameters={"query": "hello"},
|
||||
credential_id=None,
|
||||
)
|
||||
result = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
agent_tool=agent_tool,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert result is tool_runtime
|
||||
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
|
||||
|
||||
|
||||
def test_get_workflow_runtime_apply_runtime_parameters():
|
||||
parameter = ToolParameter.get_simple_instance(
|
||||
name="query",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
parameter.form = ToolParameter.ToolParameterForm.FORM
|
||||
|
||||
workflow_tool = SimpleNamespace(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tool_configurations={"query": "hello"},
|
||||
credential_id=None,
|
||||
)
|
||||
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
|
||||
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager):
|
||||
workflow_result = ToolManager.get_workflow_tool_runtime(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
workflow_tool=workflow_tool,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert workflow_result is tool_runtime2
|
||||
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
|
||||
|
||||
|
||||
def test_get_agent_runtime_raises_when_runtime_missing():
|
||||
tool_runtime = SimpleNamespace(runtime=None, get_merged_runtime_parameters=lambda: [])
|
||||
agent_tool = SimpleNamespace(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tool_parameters={},
|
||||
credential_id=None,
|
||||
)
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={}):
|
||||
with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=Mock()):
|
||||
with pytest.raises(ValueError, match="runtime not found"):
|
||||
ToolManager.get_agent_tool_runtime(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
agent_tool=agent_tool,
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
|
||||
form_param = ToolParameter.get_simple_instance(
|
||||
name="q",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
form_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
llm_param = ToolParameter.get_simple_instance(
|
||||
name="llm",
|
||||
llm_description="llm",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
llm_param.form = ToolParameter.ToolParameterForm.LLM
|
||||
|
||||
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
|
||||
result = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type=ToolProviderType.API,
|
||||
tenant_id="tenant-1",
|
||||
provider="api-1",
|
||||
tool_name="search",
|
||||
tool_parameters={"q": "hello", "llm": "ignore"},
|
||||
)
|
||||
|
||||
assert result is tool_entity
|
||||
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
|
||||
|
||||
|
||||
def test_hardcoded_provider_icon_success():
|
||||
provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg")))
|
||||
with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider):
|
||||
with patch("core.tools.tool_manager.path.exists", return_value=True):
|
||||
with patch("core.tools.tool_manager.mimetypes.guess_type", return_value=("image/svg+xml", None)):
|
||||
icon_path, mime = ToolManager.get_hardcoded_provider_icon("time")
|
||||
assert icon_path.endswith("icon.svg")
|
||||
assert mime == "image/svg+xml"
|
||||
|
||||
|
||||
def test_hardcoded_provider_icon_missing_raises():
|
||||
provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg")))
|
||||
with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider):
|
||||
with patch("core.tools.tool_manager.path.exists", return_value=False):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="icon not found"):
|
||||
ToolManager.get_hardcoded_provider_icon("time")
|
||||
|
||||
|
||||
def test_list_hardcoded_providers_cache_hit():
|
||||
ToolManager._hardcoded_providers = {"p": Mock()}
|
||||
ToolManager._builtin_providers_loaded = True
|
||||
assert list(ToolManager.list_hardcoded_providers()) == list(ToolManager._hardcoded_providers.values())
|
||||
|
||||
|
||||
def test_clear_hardcoded_providers_cache_resets():
|
||||
ToolManager._hardcoded_providers = {"p": Mock()}
|
||||
ToolManager._builtin_providers_loaded = True
|
||||
ToolManager.clear_hardcoded_providers_cache()
|
||||
assert ToolManager._hardcoded_providers == {}
|
||||
assert ToolManager._builtin_providers_loaded is False
|
||||
|
||||
|
||||
def test_list_hardcoded_providers_internal_loader():
|
||||
good_provider = SimpleNamespace(
|
||||
entity=SimpleNamespace(identity=SimpleNamespace(name="good")),
|
||||
get_tools=lambda: [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="tool-a", label="A")))],
|
||||
)
|
||||
provider_class = Mock(return_value=good_provider)
|
||||
|
||||
with patch("core.tools.tool_manager.listdir", return_value=["good", "bad", "__skip"]):
|
||||
with patch("core.tools.tool_manager.path.isdir", side_effect=lambda p: "good" in p or "bad" in p):
|
||||
with patch(
|
||||
"core.tools.tool_manager.load_single_subclass_from_source",
|
||||
side_effect=[provider_class, RuntimeError("boom")],
|
||||
):
|
||||
ToolManager._hardcoded_providers = {}
|
||||
ToolManager._builtin_tools_labels = {}
|
||||
providers = list(ToolManager._list_hardcoded_providers())
|
||||
|
||||
assert providers == [good_provider]
|
||||
assert ToolManager._hardcoded_providers["good"] is good_provider
|
||||
assert ToolManager._builtin_tools_labels["tool-a"] == "A"
|
||||
assert ToolManager._builtin_providers_loaded is True
|
||||
|
||||
|
||||
def test_get_tool_label_loads_cache_and_handles_missing():
|
||||
ToolManager._builtin_tools_labels = {}
|
||||
|
||||
def _load():
|
||||
ToolManager._builtin_tools_labels["tool-a"] = "Label A"
|
||||
|
||||
with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load):
|
||||
assert ToolManager.get_tool_label("tool-a") == "Label A"
|
||||
assert ToolManager.get_tool_label("missing") is None
|
||||
|
||||
|
||||
def test_list_default_builtin_providers_for_postgres_and_mysql():
|
||||
provider_records = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
|
||||
|
||||
for scheme in ("postgresql", "mysql"):
|
||||
session = Mock()
|
||||
session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
|
||||
session.query.return_value.where.return_value.all.return_value = provider_records
|
||||
|
||||
with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
|
||||
providers = ToolManager.list_default_builtin_providers("tenant-1")
|
||||
|
||||
assert providers == provider_records
|
||||
|
||||
|
||||
def test_list_providers_from_api_covers_builtin_api_workflow_and_mcp(monkeypatch):
|
||||
hardcoded_controller = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="hardcoded")))
|
||||
plugin_controller = object.__new__(PluginToolProviderController)
|
||||
plugin_controller.entity = SimpleNamespace(identity=SimpleNamespace(name="plugin-provider"))
|
||||
|
||||
api_db_provider_good = SimpleNamespace(id="api-1")
|
||||
api_db_provider_bad = SimpleNamespace(id="api-2")
|
||||
api_controller = SimpleNamespace(provider_id="api-1")
|
||||
|
||||
workflow_db_provider_good = SimpleNamespace(id="wf-1")
|
||||
workflow_db_provider_bad = SimpleNamespace(id="wf-2")
|
||||
workflow_controller = SimpleNamespace(provider_id="wf-1")
|
||||
|
||||
session = Mock()
|
||||
session.scalars.side_effect = [
|
||||
SimpleNamespace(all=lambda: [api_db_provider_good, api_db_provider_bad]),
|
||||
SimpleNamespace(all=lambda: [workflow_db_provider_good, workflow_db_provider_bad]),
|
||||
]
|
||||
|
||||
_setup_list_providers_from_api_mocks(
|
||||
monkeypatch,
|
||||
session=session,
|
||||
hardcoded_controller=hardcoded_controller,
|
||||
plugin_controller=plugin_controller,
|
||||
api_controller=api_controller,
|
||||
workflow_controller=workflow_controller,
|
||||
)
|
||||
providers = ToolManager.list_providers_from_api(user_id="user-1", tenant_id="tenant-1", typ="")
|
||||
|
||||
names = {provider.name for provider in providers}
|
||||
assert {"hardcoded", "plugin-provider", "api-provider", "workflow-provider", "mcp-provider"} <= names
|
||||
|
||||
|
||||
def test_get_api_provider_controller_returns_controller_and_credentials():
|
||||
provider = SimpleNamespace(
|
||||
id="api-1",
|
||||
tenant_id="tenant-1",
|
||||
name="api-provider",
|
||||
description="desc",
|
||||
credentials={"auth_type": "api_key_query"},
|
||||
credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}',
|
||||
schema_type="openapi",
|
||||
schema="schema",
|
||||
tools=[],
|
||||
icon='{"background": "#000", "content": "A"}',
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
with patch(
|
||||
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
|
||||
) as mock_from_db:
|
||||
built_controller, credentials = ToolManager.get_api_provider_controller("tenant-1", "api-1")
|
||||
|
||||
assert built_controller is controller
|
||||
assert credentials == provider.credentials
|
||||
mock_from_db.assert_called_with(provider, ApiProviderAuthType.API_KEY_QUERY)
|
||||
controller.load_bundled_tools.assert_called_once_with(provider.tools)
|
||||
|
||||
|
||||
def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
provider = SimpleNamespace(
|
||||
id="api-1",
|
||||
tenant_id="tenant-1",
|
||||
name="api-provider",
|
||||
description="desc",
|
||||
credentials={"auth_type": "api_key_query"},
|
||||
credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}',
|
||||
schema_type="openapi",
|
||||
schema="schema",
|
||||
tools=[],
|
||||
icon='{"background": "#000", "content": "A"}',
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key_value": "secret"}
|
||||
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
|
||||
with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())):
|
||||
with patch("core.tools.tool_manager.ToolLabelManager.get_tool_labels", return_value=["search"]):
|
||||
user_payload = ToolManager.user_get_api_provider("api-provider", "tenant-1")
|
||||
|
||||
assert user_payload["credentials"]["api_key_value"] == "***"
|
||||
assert user_payload["labels"] == ["search"]
|
||||
|
||||
|
||||
def test_get_api_provider_controller_not_found_raises():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
|
||||
ToolManager.get_api_provider_controller("tenant-1", "missing")
|
||||
|
||||
|
||||
def test_get_mcp_provider_controller_returns_controller():
|
||||
provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"})
|
||||
controller = Mock()
|
||||
session = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
|
||||
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
|
||||
mock_service = mock_service_cls.return_value
|
||||
mock_service.get_provider.return_value = provider_entity
|
||||
with patch("core.tools.tool_manager.MCPToolProviderController.from_db", return_value=controller):
|
||||
built = ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1")
|
||||
assert built is controller
|
||||
|
||||
|
||||
def test_generate_mcp_tool_icon_url_returns_provider_icon():
|
||||
provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"})
|
||||
session = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
|
||||
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
|
||||
mock_service = mock_service_cls.return_value
|
||||
mock_service.get_provider_entity.return_value = provider_entity
|
||||
assert ToolManager.generate_mcp_tool_icon_url("tenant-1", "mcp-1") == provider_entity.provider_icon
|
||||
|
||||
|
||||
def test_get_mcp_provider_controller_missing_raises():
|
||||
session = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.tool_manager.Session", return_value=_cm(session)):
|
||||
with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls:
|
||||
mock_service_cls.return_value.get_provider.side_effect = ValueError("missing")
|
||||
with pytest.raises(ToolProviderNotFoundError, match="mcp provider mcp-1 not found"):
|
||||
ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1")
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_for_builtin_and_plugin():
|
||||
with patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "https://console.example.com"):
|
||||
builtin_url = ToolManager.generate_builtin_tool_icon_url("time")
|
||||
plugin_url = ToolManager.generate_plugin_tool_icon_url("tenant-1", "icon.svg")
|
||||
|
||||
assert builtin_url.endswith("/tool-provider/builtin/time/icon")
|
||||
assert "/plugin/icon" in plugin_url
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_for_workflow_and_api():
|
||||
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
|
||||
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
|
||||
|
||||
def test_get_tool_icon_for_builtin_provider_variants():
|
||||
plugin_provider = object.__new__(PluginToolProviderController)
|
||||
plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg"))
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_provider):
|
||||
with patch.object(ToolManager, "generate_plugin_tool_icon_url", return_value="plugin-icon"):
|
||||
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "plugin-provider") == "plugin-icon"
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=SimpleNamespace()):
|
||||
with patch.object(ToolManager, "generate_builtin_tool_icon_url", return_value="builtin-icon"):
|
||||
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "time") == "builtin-icon"
|
||||
|
||||
|
||||
def test_get_tool_icon_for_api_workflow_and_mcp():
|
||||
with patch.object(ToolManager, "generate_api_tool_icon_url", return_value={"background": "#000"}):
|
||||
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.API, "api-1") == {"background": "#000"}
|
||||
|
||||
with patch.object(ToolManager, "generate_workflow_tool_icon_url", return_value={"background": "#111"}):
|
||||
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.WORKFLOW, "wf-1") == {"background": "#111"}
|
||||
|
||||
with patch.object(ToolManager, "generate_mcp_tool_icon_url", return_value={"background": "#222"}):
|
||||
assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.MCP, "mcp-1") == {"background": "#222"}
|
||||
|
||||
|
||||
def test_get_tool_icon_plugin_error_returns_default():
|
||||
plugin_provider = object.__new__(PluginToolProviderController)
|
||||
plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg"))
|
||||
|
||||
with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider):
|
||||
with patch.object(ToolManager, "generate_plugin_tool_icon_url", side_effect=RuntimeError("fail")):
|
||||
icon = ToolManager.get_tool_icon("tenant-1", ToolProviderType.PLUGIN, "plugin-provider")
|
||||
assert icon["background"] == "#252525"
|
||||
|
||||
|
||||
def test_get_tool_icon_invalid_provider_type_raises():
|
||||
with pytest.raises(ValueError, match="provider type"):
|
||||
ToolManager.get_tool_icon("tenant-1", "invalid", "x") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_convert_tool_parameters_type_agent_and_workflow_branches():
|
||||
file_param = ToolParameter.get_simple_instance(
|
||||
name="file",
|
||||
llm_description="file",
|
||||
typ=ToolParameter.ToolParameterType.FILE,
|
||||
required=True,
|
||||
)
|
||||
file_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
|
||||
with pytest.raises(ValueError, match="file type parameter file not supported in agent"):
|
||||
ToolManager._convert_tool_parameters_type(
|
||||
parameters=[file_param],
|
||||
variable_pool=None,
|
||||
tool_configurations={"file": "x"},
|
||||
typ="agent",
|
||||
)
|
||||
|
||||
text_param = ToolParameter.get_simple_instance(
|
||||
name="text",
|
||||
llm_description="text",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
text_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
plain = ToolManager._convert_tool_parameters_type(
|
||||
parameters=[text_param],
|
||||
variable_pool=None,
|
||||
tool_configurations={"text": "hello"},
|
||||
typ="workflow",
|
||||
)
|
||||
assert plain == {"text": "hello"}
|
||||
|
||||
variable_pool = Mock()
|
||||
variable_pool.get.return_value = SimpleNamespace(value="from-variable")
|
||||
variable_pool.convert_template.return_value = SimpleNamespace(text="from-template")
|
||||
|
||||
mixed = ToolManager._convert_tool_parameters_type(
|
||||
parameters=[text_param],
|
||||
variable_pool=variable_pool,
|
||||
tool_configurations={"text": {"type": "mixed", "value": "Hello {{name}}"}},
|
||||
typ="workflow",
|
||||
)
|
||||
assert mixed == {"text": "from-template"}
|
||||
|
||||
variable = ToolManager._convert_tool_parameters_type(
|
||||
parameters=[text_param],
|
||||
variable_pool=variable_pool,
|
||||
tool_configurations={"text": {"type": "variable", "value": ["sys", "query"]}},
|
||||
typ="workflow",
|
||||
)
|
||||
assert variable == {"text": "from-variable"}
|
||||
|
||||
|
||||
def test_convert_tool_parameters_type_constant_branch():
|
||||
text_param = ToolParameter.get_simple_instance(
|
||||
name="text",
|
||||
llm_description="text",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
text_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
variable_pool = Mock()
|
||||
|
||||
constant = ToolManager._convert_tool_parameters_type(
|
||||
parameters=[text_param],
|
||||
variable_pool=variable_pool,
|
||||
tool_configurations={"text": {"type": "constant", "value": "fixed"}},
|
||||
typ="workflow",
|
||||
)
|
||||
|
||||
assert constant == {"text": "fixed"}
|
||||
110
api/tests/unit_tests/core/tools/test_tool_provider_controller.py
Normal file
110
api/tests/unit_tests/core/tools/test_tool_provider_controller.py
Normal file
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class _DummyTool(Tool):
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
class _DummyController(ToolProviderController):
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name=tool_name,
|
||||
label=I18nObject(en_US=tool_name),
|
||||
provider="provider",
|
||||
),
|
||||
parameters=[],
|
||||
)
|
||||
return _DummyTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant"))
|
||||
|
||||
|
||||
def _provider_identity() -> ToolProviderIdentity:
|
||||
return ToolProviderIdentity(
|
||||
author="author",
|
||||
name="provider",
|
||||
description=I18nObject(en_US="desc"),
|
||||
icon="icon.svg",
|
||||
label=I18nObject(en_US="Provider"),
|
||||
)
|
||||
|
||||
|
||||
def test_tool_provider_controller_get_credentials_schema_returns_deep_copy():
|
||||
entity = ToolProviderEntity(
|
||||
identity=_provider_identity(),
|
||||
credentials_schema=[ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="api_key", required=False)],
|
||||
)
|
||||
controller = _DummyController(entity=entity)
|
||||
|
||||
schema = controller.get_credentials_schema()
|
||||
schema[0].name = "changed"
|
||||
|
||||
assert controller.entity.credentials_schema[0].name == "api_key"
|
||||
|
||||
|
||||
def test_tool_provider_controller_default_provider_type():
|
||||
entity = ToolProviderEntity(identity=_provider_identity(), credentials_schema=[])
|
||||
controller = _DummyController(entity=entity)
|
||||
|
||||
assert controller.provider_type == ToolProviderType.BUILT_IN
|
||||
|
||||
|
||||
def test_validate_credentials_format_covers_required_default_and_type_rules():
|
||||
select_options = [ProviderConfig.Option(value="opt-a", label=I18nObject(en_US="A"))]
|
||||
entity = ToolProviderEntity(
|
||||
identity=_provider_identity(),
|
||||
credentials_schema=[
|
||||
ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="required_text", required=True),
|
||||
ProviderConfig(type=ProviderConfig.Type.SECRET_INPUT, name="secret", required=False),
|
||||
ProviderConfig(type=ProviderConfig.Type.SELECT, name="choice", required=False, options=select_options),
|
||||
ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="with_default", required=False, default="x"),
|
||||
],
|
||||
)
|
||||
controller = _DummyController(entity=entity)
|
||||
|
||||
credentials = {"required_text": "value", "secret": None, "choice": "opt-a"}
|
||||
controller.validate_credentials_format(credentials)
|
||||
assert credentials["with_default"] == "x"
|
||||
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="not found"):
|
||||
controller.validate_credentials_format({"required_text": "value", "unknown": "v"})
|
||||
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="is required"):
|
||||
controller.validate_credentials_format({"secret": "s"})
|
||||
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="should be string"):
|
||||
controller.validate_credentials_format({"required_text": 123}) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="should be one of"):
|
||||
controller.validate_credentials_format({"required_text": "value", "choice": "opt-b"})
|
||||
148
api/tests/unit_tests/core/tools/utils/test_configuration.py
Normal file
148
api/tests/unit_tests/core/tools/utils/test_configuration.py
Normal file
@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
|
||||
|
||||
class _DummyTool(Tool):
|
||||
runtime_overrides: list[ToolParameter]
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, runtime_overrides: list[ToolParameter]):
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
self.runtime_overrides = runtime_overrides
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
return self.runtime_overrides
|
||||
|
||||
|
||||
def _param(
|
||||
name: str,
|
||||
*,
|
||||
typ: ToolParameter.ToolParameterType,
|
||||
form: ToolParameter.ToolParameterForm,
|
||||
required: bool = False,
|
||||
) -> ToolParameter:
|
||||
return ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name),
|
||||
placeholder=I18nObject(en_US=""),
|
||||
human_description=I18nObject(en_US=""),
|
||||
type=typ,
|
||||
form=form,
|
||||
required=required,
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_manager() -> ToolParameterConfigurationManager:
|
||||
base_params = [
|
||||
_param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM),
|
||||
_param("plain", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM),
|
||||
]
|
||||
runtime_overrides = [
|
||||
_param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM),
|
||||
_param("runtime_only", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM),
|
||||
]
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="a", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
|
||||
parameters=base_params,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
tool = _DummyTool(entity=entity, runtime=runtime, runtime_overrides=runtime_overrides)
|
||||
return ToolParameterConfigurationManager(
|
||||
tenant_id="tenant-1",
|
||||
tool_runtime=tool,
|
||||
provider_name="provider-a",
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
identity_id="ID.1",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_and_mask_parameters():
|
||||
manager = _build_manager()
|
||||
|
||||
masked = manager.mask_tool_parameters({"secret": "abcdefghi", "plain": "x", "runtime_only": "y"})
|
||||
assert masked["secret"] == "ab*****hi"
|
||||
assert masked["plain"] == "x"
|
||||
assert masked["runtime_only"] == "y"
|
||||
|
||||
|
||||
def test_encrypt_tool_parameters():
|
||||
manager = _build_manager()
|
||||
|
||||
with patch("core.tools.utils.configuration.encrypter.encrypt_token", return_value="enc"):
|
||||
encrypted = manager.encrypt_tool_parameters({"secret": "raw", "plain": "x"})
|
||||
|
||||
assert encrypted["secret"] == "enc"
|
||||
assert encrypted["plain"] == "x"
|
||||
|
||||
|
||||
def test_decrypt_tool_parameters_cache_hit_and_miss():
|
||||
manager = _build_manager()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = {"secret": "cached"}
|
||||
assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"}
|
||||
cache.set.assert_not_called()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = None
|
||||
with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
|
||||
|
||||
assert decrypted["secret"] == "dec"
|
||||
cache.set.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_tool_parameters_cache():
|
||||
manager = _build_manager()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
cache_cls.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_configuration_manager_decrypt_suppresses_errors():
|
||||
manager = _build_manager()
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = None
|
||||
with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
|
||||
# decryption failure is suppressed, original value is retained.
|
||||
assert decrypted["secret"] == "enc"
|
||||
@ -1,10 +1,13 @@
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
|
||||
|
||||
# ---------------------------
|
||||
@ -13,13 +16,13 @@ from core.helper.provider_encryption import ProviderConfigEncrypter
|
||||
class NoopCache:
|
||||
"""Simple cache stub: always returns None, does nothing for set/delete."""
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Any | None:
|
||||
return None
|
||||
|
||||
def set(self, config):
|
||||
def set(self, config: Any) -> None:
|
||||
pass
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -179,3 +182,35 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
|
||||
out = encrypter_obj.decrypt({"password": "ENC_ERR"})
|
||||
|
||||
assert out["password"] == "ENC_ERR"
|
||||
|
||||
|
||||
def test_create_tool_provider_encrypter_builds_cache_and_encrypter():
|
||||
basic_config = BasicProviderConfig(name="key", type=BasicProviderConfig.Type.TEXT_INPUT)
|
||||
credential_schema_item = SimpleNamespace(to_basic_provider_config=lambda: basic_config)
|
||||
controller = SimpleNamespace(
|
||||
provider_type=SimpleNamespace(value="builtin"),
|
||||
entity=SimpleNamespace(identity=SimpleNamespace(name="provider-a")),
|
||||
get_credentials_schema=lambda: [credential_schema_item],
|
||||
)
|
||||
|
||||
cache_instance = Mock()
|
||||
encrypter_instance = Mock()
|
||||
|
||||
with patch(
|
||||
"core.tools.utils.encryption.SingletonProviderCredentialsCache", return_value=cache_instance
|
||||
) as cache_cls:
|
||||
with patch("core.tools.utils.encryption.ProviderConfigEncrypter", return_value=encrypter_instance) as enc_cls:
|
||||
encrypter, cache = create_tool_provider_encrypter("tenant-1", controller)
|
||||
|
||||
assert encrypter is encrypter_instance
|
||||
assert cache is cache_instance
|
||||
cache_cls.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider_type="builtin",
|
||||
provider_identity="provider-a",
|
||||
)
|
||||
enc_cls.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
config=[basic_config],
|
||||
provider_config_cache=cache_instance,
|
||||
)
|
||||
|
||||
478
api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py
Normal file
478
api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py
Normal file
@ -0,0 +1,478 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from yaml import YAMLError
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.tools.utils.dataset_retriever import dataset_multi_retriever_tool as multi_retriever_module
|
||||
from core.tools.utils.dataset_retriever import dataset_retriever_tool as single_retriever_module
|
||||
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool as SingleDatasetRetrieverTool
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file, load_yaml_file_cached
|
||||
|
||||
|
||||
def _retrieve_config() -> DatasetRetrieveConfigEntity:
|
||||
return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE)
|
||||
|
||||
|
||||
class _FakeFlaskApp:
|
||||
def app_context(self):
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target=None, kwargs=None, **_kwargs):
|
||||
self._target = target
|
||||
self._kwargs = kwargs or {}
|
||||
|
||||
def start(self):
|
||||
if self._target is not None:
|
||||
self._target(**self._kwargs)
|
||||
|
||||
def join(self):
|
||||
return None
|
||||
|
||||
|
||||
class _TestHitCallback(DatasetIndexToolCallbackHandler):
|
||||
def __init__(self):
|
||||
self.queries: list[tuple[str, str]] = []
|
||||
self.documents: list[RagDocument] | None = None
|
||||
self.resources = None
|
||||
|
||||
def on_query(self, query: str, dataset_id: str):
|
||||
self.queries.append((query, dataset_id))
|
||||
|
||||
def on_tool_end(self, documents: list[RagDocument]):
|
||||
self.documents = documents
|
||||
|
||||
def return_retriever_resource_info(self, resource):
|
||||
self.resources = list(resource)
|
||||
|
||||
|
||||
def test_remove_leading_symbols_preserves_markdown_link_and_strips_punctuation():
|
||||
markdown = "[Example](https://example.com) content"
|
||||
assert remove_leading_symbols(markdown) == markdown
|
||||
|
||||
assert remove_leading_symbols("...Hello world") == "Hello world"
|
||||
|
||||
|
||||
def test_is_valid_uuid_handles_valid_invalid_and_empty_values():
|
||||
assert is_valid_uuid(str(uuid.uuid4())) is True
|
||||
assert is_valid_uuid("not-a-uuid") is False
|
||||
assert is_valid_uuid("") is False
|
||||
assert is_valid_uuid(None) is False
|
||||
|
||||
|
||||
def test_load_yaml_file_valid(tmp_path):
|
||||
valid_file = tmp_path / "valid.yaml"
|
||||
valid_file.write_text("a: 1\nb: two\n", encoding="utf-8")
|
||||
|
||||
loaded = _load_yaml_file(file_path=str(valid_file))
|
||||
|
||||
assert loaded == {"a": 1, "b": "two"}
|
||||
|
||||
|
||||
def test_load_yaml_file_missing(tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_load_yaml_file(file_path=str(tmp_path / "missing.yaml"))
|
||||
|
||||
|
||||
def test_load_yaml_file_invalid(tmp_path):
|
||||
invalid_file = tmp_path / "invalid.yaml"
|
||||
invalid_file.write_text("a: [1, 2\n", encoding="utf-8")
|
||||
|
||||
with pytest.raises(YAMLError):
|
||||
_load_yaml_file(file_path=str(invalid_file))
|
||||
|
||||
|
||||
def test_load_yaml_file_cached_hits(tmp_path):
|
||||
valid_file = tmp_path / "valid.yaml"
|
||||
valid_file.write_text("a: 1\nb: two\n", encoding="utf-8")
|
||||
|
||||
load_yaml_file_cached.cache_clear()
|
||||
assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"}
|
||||
|
||||
assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"}
|
||||
assert load_yaml_file_cached.cache_info().hits == 1
|
||||
|
||||
|
||||
def test_single_dataset_retriever_from_dataset_builds_name_and_description():
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="Knowledge", description=None)
|
||||
|
||||
tool = SingleDatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
retrieve_config=_retrieve_config(),
|
||||
return_resource=False,
|
||||
retriever_from="prod",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
assert tool.name == "dataset_dataset_1"
|
||||
assert tool.description == "useful for when you want to answer queries about the Knowledge"
|
||||
|
||||
|
||||
def test_single_dataset_retriever_external_run_returns_content_and_resources():
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
name="Knowledge Base",
|
||||
provider="external",
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model={},
|
||||
)
|
||||
callback = _TestHitCallback()
|
||||
dataset_retrieval = Mock()
|
||||
dataset_retrieval.get_metadata_filter_condition.return_value = (
|
||||
{"dataset-1": ["doc-a"]},
|
||||
{"logical_operator": "and"},
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.return_value = dataset
|
||||
external_documents = [
|
||||
{"content": "first", "metadata": {"document_id": "doc-a"}, "score": 0.9, "title": "Doc A"},
|
||||
{"content": "second", "metadata": {"document_id": "doc-b"}, "score": 0.8, "title": "Doc B"},
|
||||
]
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
retrieve_config=_retrieve_config(),
|
||||
return_resource=True,
|
||||
retriever_from="dev",
|
||||
hit_callbacks=[callback],
|
||||
inputs={"x": 1},
|
||||
)
|
||||
|
||||
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
|
||||
with patch.object(
|
||||
single_retriever_module.ExternalDatasetService,
|
||||
"fetch_external_knowledge_retrieval",
|
||||
return_value=external_documents,
|
||||
) as fetch_mock:
|
||||
result = tool.run(query="hello")
|
||||
|
||||
assert result == "first\nsecond"
|
||||
assert callback.queries == [("hello", "dataset-1")]
|
||||
assert callback.resources is not None
|
||||
resource_info = callback.resources
|
||||
assert [item.position for item in resource_info] == [1, 2]
|
||||
assert resource_info[0].dataset_id == "dataset-1"
|
||||
fetch_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_single_dataset_retriever_returns_empty_when_metadata_filter_finds_no_documents():
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
name="Knowledge Base",
|
||||
provider="internal",
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model=None,
|
||||
)
|
||||
dataset_retrieval = Mock()
|
||||
dataset_retrieval.get_metadata_filter_condition.return_value = ({"dataset-1": []}, {"logical_operator": "and"})
|
||||
db_session = Mock()
|
||||
db_session.scalar.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
retrieve_config=_retrieve_config(),
|
||||
return_resource=False,
|
||||
retriever_from="prod",
|
||||
hit_callbacks=[_TestHitCallback()],
|
||||
inputs={},
|
||||
)
|
||||
|
||||
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
|
||||
with patch.object(single_retriever_module.RetrievalService, "retrieve") as retrieve_mock:
|
||||
result = tool.run(query="hello")
|
||||
|
||||
assert result == ""
|
||||
retrieve_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
name="Knowledge Base",
|
||||
provider="internal",
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model={
|
||||
"search_method": "semantic_search",
|
||||
"score_threshold_enabled": True,
|
||||
"score_threshold": 0.2,
|
||||
"reranking_enable": True,
|
||||
"reranking_model": {"reranking_provider_name": "provider", "reranking_model_name": "model"},
|
||||
"reranking_mode": "reranking_model",
|
||||
"weights": {"vector_setting": {"vector_weight": 0.6}},
|
||||
},
|
||||
)
|
||||
callback = _TestHitCallback()
|
||||
dataset_retrieval = Mock()
|
||||
dataset_retrieval.get_metadata_filter_condition.return_value = (None, None)
|
||||
low_segment = SimpleNamespace(
|
||||
id="seg-low",
|
||||
dataset_id="dataset-1",
|
||||
document_id="doc-low",
|
||||
content="raw low",
|
||||
answer="low answer",
|
||||
hit_count=1,
|
||||
word_count=10,
|
||||
position=3,
|
||||
index_node_hash="hash-low",
|
||||
get_sign_content=lambda: "signed low",
|
||||
)
|
||||
high_segment = SimpleNamespace(
|
||||
id="seg-high",
|
||||
dataset_id="dataset-1",
|
||||
document_id="doc-high",
|
||||
content="raw high",
|
||||
answer=None,
|
||||
hit_count=9,
|
||||
word_count=25,
|
||||
position=1,
|
||||
index_node_hash="hash-high",
|
||||
get_sign_content=lambda: "signed high",
|
||||
)
|
||||
records = [
|
||||
SimpleNamespace(segment=low_segment, score=0.2, summary="summary low"),
|
||||
SimpleNamespace(segment=high_segment, score=0.9, summary=None),
|
||||
]
|
||||
documents = [
|
||||
RagDocument(page_content="first", metadata={"doc_id": "node-low", "score": 0.2}),
|
||||
RagDocument(page_content="second", metadata={"doc_id": "node-high", "score": 0.9}),
|
||||
]
|
||||
lookup_doc_low = SimpleNamespace(
|
||||
id="doc-low", name="Document Low", data_source_type="upload_file", doc_metadata={"lang": "en"}
|
||||
)
|
||||
lookup_doc_high = SimpleNamespace(
|
||||
id="doc-high", name="Document High", data_source_type="notion", doc_metadata={"lang": "fr"}
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
|
||||
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
retrieve_config=_retrieve_config(),
|
||||
return_resource=True,
|
||||
retriever_from="dev",
|
||||
hit_callbacks=[callback],
|
||||
inputs={},
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval):
|
||||
with patch.object(single_retriever_module.RetrievalService, "retrieve", return_value=documents):
|
||||
with patch.object(
|
||||
single_retriever_module.RetrievalService,
|
||||
"format_retrieval_documents",
|
||||
return_value=records,
|
||||
):
|
||||
result = tool.run(query="hello")
|
||||
|
||||
assert result == "signed high\nsummary low\nquestion:signed low answer:low answer"
|
||||
assert callback.documents == documents
|
||||
assert callback.resources is not None
|
||||
resource_info = callback.resources
|
||||
assert [item.position for item in resource_info] == [1, 2]
|
||||
assert resource_info[0].segment_id == "seg-high"
|
||||
assert resource_info[0].hit_count == 9
|
||||
assert resource_info[1].summary == "summary low"
|
||||
assert resource_info[1].content == "question:raw low \nanswer:low answer"
|
||||
|
||||
|
||||
def test_multi_dataset_retriever_from_dataset_sets_tool_name():
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=["dataset-1"],
|
||||
tenant_id="tenant-1",
|
||||
reranking_provider_name="provider",
|
||||
reranking_model_name="model",
|
||||
return_resource=False,
|
||||
retriever_from="prod",
|
||||
)
|
||||
|
||||
assert tool.name == "dataset_tenant_1"
|
||||
|
||||
|
||||
def test_multi_dataset_retriever_retriever_returns_early_when_dataset_is_missing():
|
||||
callback = _TestHitCallback()
|
||||
all_documents: list[RagDocument] = []
|
||||
db_session = Mock()
|
||||
db_session.scalar.return_value = None
|
||||
tool = DatasetMultiRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_ids=["dataset-1"],
|
||||
reranking_provider_name="provider",
|
||||
reranking_model_name="model",
|
||||
return_resource=False,
|
||||
retriever_from="prod",
|
||||
)
|
||||
|
||||
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
with patch.object(multi_retriever_module.RetrievalService, "retrieve") as retrieve_mock:
|
||||
result = tool._retriever(
|
||||
flask_app=_FakeFlaskApp(),
|
||||
dataset_id="dataset-1",
|
||||
query="hello",
|
||||
all_documents=all_documents,
|
||||
hit_callbacks=[callback],
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert all_documents == []
|
||||
assert callback.queries == []
|
||||
retrieve_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_multi_dataset_retriever_retriever_non_economy_uses_retrieval_model():
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model={
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 6,
|
||||
"score_threshold_enabled": True,
|
||||
"score_threshold": 0.4,
|
||||
"reranking_enable": False,
|
||||
"reranking_mode": None,
|
||||
"weights": {"balanced": True},
|
||||
},
|
||||
)
|
||||
callback = _TestHitCallback()
|
||||
documents = [RagDocument(page_content="retrieved", metadata={"doc_id": "node-1", "score": 0.4})]
|
||||
all_documents: list[RagDocument] = []
|
||||
db_session = Mock()
|
||||
db_session.scalar.return_value = dataset
|
||||
tool = DatasetMultiRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_ids=["dataset-1"],
|
||||
reranking_provider_name="provider",
|
||||
reranking_model_name="model",
|
||||
return_resource=False,
|
||||
retriever_from="prod",
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
with patch.object(multi_retriever_module.RetrievalService, "retrieve", return_value=documents) as retrieve_mock:
|
||||
tool._retriever(
|
||||
flask_app=_FakeFlaskApp(),
|
||||
dataset_id="dataset-1",
|
||||
query="hello",
|
||||
all_documents=all_documents,
|
||||
hit_callbacks=[callback],
|
||||
)
|
||||
|
||||
assert all_documents == documents
|
||||
assert callback.queries == [("hello", "dataset-1")]
|
||||
retrieve_mock.assert_called_once_with(
|
||||
retrieval_method="semantic_search",
|
||||
dataset_id="dataset-1",
|
||||
query="hello",
|
||||
top_k=6,
|
||||
score_threshold=0.4,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights={"balanced": True},
|
||||
)
|
||||
|
||||
|
||||
def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
|
||||
callback = _TestHitCallback()
|
||||
tool = DatasetMultiRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
dataset_ids=["dataset-1", "dataset-2"],
|
||||
reranking_provider_name="provider",
|
||||
reranking_model_name="model",
|
||||
return_resource=True,
|
||||
retriever_from="dev",
|
||||
hit_callbacks=[callback],
|
||||
top_k=2,
|
||||
score_threshold=0.1,
|
||||
)
|
||||
first_doc = RagDocument(page_content="first", metadata={"doc_id": "node-2", "score": 0.4})
|
||||
second_doc = RagDocument(page_content="second", metadata={"doc_id": "node-1", "score": 0.9})
|
||||
|
||||
def fake_retriever(**kwargs):
|
||||
if kwargs["dataset_id"] == "dataset-1":
|
||||
kwargs["all_documents"].append(first_doc)
|
||||
else:
|
||||
kwargs["all_documents"].append(second_doc)
|
||||
|
||||
segment_for_node_2 = SimpleNamespace(
|
||||
id="seg-2",
|
||||
dataset_id="dataset-1",
|
||||
document_id="doc-2",
|
||||
index_node_id="node-2",
|
||||
content="raw two",
|
||||
answer="answer two",
|
||||
hit_count=2,
|
||||
word_count=20,
|
||||
position=2,
|
||||
index_node_hash="hash-2",
|
||||
get_sign_content=lambda: "signed two",
|
||||
)
|
||||
segment_for_node_1 = SimpleNamespace(
|
||||
id="seg-1",
|
||||
dataset_id="dataset-2",
|
||||
document_id="doc-1",
|
||||
index_node_id="node-1",
|
||||
content="raw one",
|
||||
answer=None,
|
||||
hit_count=7,
|
||||
word_count=30,
|
||||
position=1,
|
||||
index_node_hash="hash-1",
|
||||
get_sign_content=lambda: "signed one",
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
|
||||
db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
SimpleNamespace(id="dataset-2", name="Dataset Two"),
|
||||
SimpleNamespace(id="dataset-1", name="Dataset One"),
|
||||
]
|
||||
db_session.scalar.side_effect = [
|
||||
SimpleNamespace(id="doc-1", name="Doc One", data_source_type="upload_file", doc_metadata={"p": 1}),
|
||||
SimpleNamespace(id="doc-2", name="Doc Two", data_source_type="notion", doc_metadata={"p": 2}),
|
||||
]
|
||||
model_manager = Mock()
|
||||
model_manager.get_model_instance.return_value = Mock()
|
||||
rerank_runner = Mock()
|
||||
rerank_runner.run.return_value = [second_doc, first_doc]
|
||||
fake_current_app = SimpleNamespace(_get_current_object=lambda: _FakeFlaskApp())
|
||||
|
||||
with patch.object(tool, "_retriever", side_effect=fake_retriever) as retriever_mock:
|
||||
with patch.object(multi_retriever_module, "current_app", fake_current_app):
|
||||
with patch.object(multi_retriever_module.threading, "Thread", _ImmediateThread):
|
||||
with patch.object(multi_retriever_module, "ModelManager", return_value=model_manager):
|
||||
with patch.object(multi_retriever_module, "RerankModelRunner", return_value=rerank_runner):
|
||||
with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)):
|
||||
result = tool.run(query="hello")
|
||||
|
||||
assert result == "signed one\nquestion:signed two answer:answer two"
|
||||
assert retriever_mock.call_count == 2
|
||||
assert callback.documents == [second_doc, first_doc]
|
||||
assert callback.resources is not None
|
||||
resource_info = callback.resources
|
||||
assert [item.position for item in resource_info] == [1, 2]
|
||||
assert resource_info[0].score == 0.9
|
||||
assert resource_info[0].content == "raw one"
|
||||
assert resource_info[1].score == 0.4
|
||||
assert resource_info[1].content == "question:raw two \nanswer:answer two"
|
||||
@ -0,0 +1,158 @@
|
||||
"""Unit tests for ModelInvocationUtils.
|
||||
|
||||
Covers success and error branches for ModelInvocationUtils, including
|
||||
InvokeModelError and invoke error mappings for InvokeAuthorizationError,
|
||||
InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, and
|
||||
InvokeServerUnavailableError. Assumes mocked model instances and managers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace:
|
||||
model_type_instance = Mock()
|
||||
model_type_instance.get_model_schema.return_value = (
|
||||
SimpleNamespace(model_properties=schema or {}) if schema is not None else None
|
||||
)
|
||||
return SimpleNamespace(
|
||||
provider="provider",
|
||||
model="model-a",
|
||||
model_name="model-a",
|
||||
credentials={"api_key": "x"},
|
||||
model_type_instance=model_type_instance,
|
||||
get_llm_num_tokens=lambda prompt_messages: 5,
|
||||
invoke_llm=Mock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_instance", "expected", "error_match"),
|
||||
[
|
||||
(None, None, "Model not found"),
|
||||
(_mock_model_instance(schema=None), None, "No model schema found"),
|
||||
(_mock_model_instance(schema={}), 2048, None),
|
||||
(_mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 8192}), 8192, None),
|
||||
],
|
||||
ids=[
|
||||
"missing-model",
|
||||
"missing-schema",
|
||||
"default-context-size",
|
||||
"schema-context-size",
|
||||
],
|
||||
)
|
||||
def test_get_max_llm_context_tokens_branches(model_instance, expected, error_match):
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
if error_match:
|
||||
with pytest.raises(InvokeModelError, match=error_match):
|
||||
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
|
||||
else:
|
||||
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
|
||||
|
||||
|
||||
def test_calculate_tokens_handles_missing_model():
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = None
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with pytest.raises(InvokeModelError, match="Model not found"):
|
||||
ModelInvocationUtils.calculate_tokens("tenant", [])
|
||||
|
||||
|
||||
def test_invoke_success_and_error_mappings():
|
||||
model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048})
|
||||
model_instance.invoke_llm.return_value = SimpleNamespace(
|
||||
message=SimpleNamespace(content="ok"),
|
||||
usage=SimpleNamespace(
|
||||
completion_tokens=7,
|
||||
completion_unit_price=Decimal("0.1"),
|
||||
completion_price_unit=Decimal(1),
|
||||
latency=0.3,
|
||||
total_price=Decimal("0.7"),
|
||||
currency="USD",
|
||||
),
|
||||
)
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
class _ToolModelInvoke:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
response = ModelInvocationUtils.invoke(
|
||||
user_id="u1",
|
||||
tenant_id="tenant",
|
||||
tool_type="builtin",
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
)
|
||||
|
||||
assert response.message.content == "ok"
|
||||
assert db_mock.session.add.call_count == 1
|
||||
assert db_mock.session.commit.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(InvokeRateLimitError("rate"), "Invoke rate limit error"),
|
||||
(InvokeBadRequestError("bad"), "Invoke bad request error"),
|
||||
(InvokeConnectionError("conn"), "Invoke connection error"),
|
||||
(InvokeAuthorizationError("auth"), "Invoke authorization error"),
|
||||
(InvokeServerUnavailableError("down"), "Invoke server unavailable error"),
|
||||
(RuntimeError("oops"), "Invoke error"),
|
||||
],
|
||||
ids=[
|
||||
"rate-limit",
|
||||
"bad-request",
|
||||
"connection",
|
||||
"authorization",
|
||||
"server-unavailable",
|
||||
"generic-error",
|
||||
],
|
||||
)
|
||||
def test_invoke_error_mappings(exc, expected):
|
||||
model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048})
|
||||
model_instance.invoke_llm.side_effect = exc
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
class _ToolModelInvoke:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
with pytest.raises(InvokeModelError, match=expected):
|
||||
ModelInvocationUtils.invoke(
|
||||
user_id="u1",
|
||||
tenant_id="tenant",
|
||||
tool_type="builtin",
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
)
|
||||
@ -1,6 +1,12 @@
|
||||
from json.decoder import JSONDecodeError
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from yaml import YAMLError
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
|
||||
|
||||
@ -189,3 +195,225 @@ def test_parse_openapi_to_tool_bundle_default_value_type_casting(app):
|
||||
available_param = params_by_name["available"]
|
||||
assert available_param.type == "boolean"
|
||||
assert available_param.default is True
|
||||
|
||||
|
||||
def test_sanitize_default_value_and_type_detection():
|
||||
assert ApiBasedToolSchemaParser._sanitize_default_value([]) is None
|
||||
assert ApiBasedToolSchemaParser._sanitize_default_value({}) is None
|
||||
assert ApiBasedToolSchemaParser._sanitize_default_value("ok") == "ok"
|
||||
|
||||
assert (
|
||||
ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"}) == ToolParameter.ToolParameterType.FILE
|
||||
)
|
||||
assert (
|
||||
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "integer"}) == ToolParameter.ToolParameterType.NUMBER
|
||||
)
|
||||
assert (
|
||||
ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "boolean"}})
|
||||
== ToolParameter.ToolParameterType.BOOLEAN
|
||||
)
|
||||
assert (
|
||||
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}})
|
||||
== ToolParameter.ToolParameterType.FILES
|
||||
)
|
||||
assert (
|
||||
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}})
|
||||
== ToolParameter.ToolParameterType.ARRAY
|
||||
)
|
||||
assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "object"}) is None
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_server_env_and_refs(app):
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "API", "version": "1.0.0", "description": "API description"},
|
||||
"servers": [
|
||||
{"url": "https://dev.example.com", "env": "dev"},
|
||||
{"url": "https://prod.example.com", "env": "prod"},
|
||||
],
|
||||
"paths": {
|
||||
"/items": {
|
||||
"post": {
|
||||
"description": "Create item",
|
||||
"parameters": [
|
||||
{"$ref": "#/components/parameters/token"},
|
||||
{"name": "token", "schema": {"type": "string"}},
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/ItemRequest"}}}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"parameters": {
|
||||
"token": {"name": "token", "required": True, "schema": {"type": "string"}},
|
||||
},
|
||||
"schemas": {
|
||||
"ItemRequest": {
|
||||
"type": "object",
|
||||
"required": ["age"],
|
||||
"properties": {"age": {"type": "integer", "description": "Age", "default": 18}},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
extra_info: dict = {}
|
||||
warning: dict = {}
|
||||
with app.test_request_context(headers={"X-Request-Env": "prod"}):
|
||||
bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
assert len(bundles) == 1
|
||||
assert bundles[0].server_url == "https://prod.example.com/items"
|
||||
assert warning["duplicated_parameter"].startswith("Parameter token")
|
||||
assert extra_info["description"] == "API description"
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_no_server_raises(app):
|
||||
openapi = {"info": {"title": "x"}, "servers": [], "paths": {}}
|
||||
with app.test_request_context():
|
||||
with pytest.raises(ToolProviderNotFoundError, match="No server found"):
|
||||
ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
|
||||
|
||||
|
||||
def test_parse_openapi_yaml_to_tool_bundle_invalid_yaml(app):
|
||||
with app.test_request_context():
|
||||
with pytest.raises(ToolApiSchemaError, match="Invalid openapi yaml"):
|
||||
ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle("null")
|
||||
|
||||
|
||||
def test_parse_swagger_to_openapi_branches():
|
||||
with pytest.raises(ToolApiSchemaError, match="No server found"):
|
||||
ApiBasedToolSchemaParser.parse_swagger_to_openapi({"info": {}, "paths": {}})
|
||||
|
||||
with pytest.raises(ToolApiSchemaError, match="No paths found"):
|
||||
ApiBasedToolSchemaParser.parse_swagger_to_openapi({"servers": [{"url": "https://x"}], "paths": {}})
|
||||
|
||||
with pytest.raises(ToolApiSchemaError, match="No operationId found"):
|
||||
ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
{
|
||||
"servers": [{"url": "https://x"}],
|
||||
"paths": {"/a": {"get": {"summary": "x", "responses": {}}}},
|
||||
}
|
||||
)
|
||||
|
||||
warning: dict = {"seed": True}
|
||||
converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
{
|
||||
"servers": [{"url": "https://x"}],
|
||||
"paths": {"/a": {"get": {"operationId": "getA", "responses": {}}}},
|
||||
"definitions": {"A": {"type": "object"}},
|
||||
},
|
||||
warning=warning,
|
||||
)
|
||||
assert converted["openapi"] == "3.0.0"
|
||||
assert converted["components"]["schemas"]["A"]["type"] == "object"
|
||||
assert warning["missing_summary"].startswith("No summary or description found")
|
||||
|
||||
|
||||
def test_parse_openai_plugin_json_branches(app):
|
||||
with app.test_request_context():
|
||||
with pytest.raises(ToolProviderNotFoundError, match="Invalid openai plugin json"):
|
||||
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle("{bad")
|
||||
|
||||
with pytest.raises(ToolNotSupportedError, match="Only openapi is supported"):
|
||||
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
'{"api": {"url": "https://x", "type": "graphql"}}'
|
||||
)
|
||||
|
||||
|
||||
def test_parse_openai_plugin_json_http_branches(app):
|
||||
with app.test_request_context():
|
||||
response = type("Resp", (), {"status_code": 500, "text": "", "close": Mock()})()
|
||||
with patch("core.tools.utils.parser.httpx.get", return_value=response):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="cannot get openapi yaml"):
|
||||
ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
'{"api": {"url": "https://x", "type": "openapi"}}'
|
||||
)
|
||||
response.close.assert_called_once()
|
||||
|
||||
success_response = type("Resp", (), {"status_code": 200, "text": "openapi: 3.0.0", "close": Mock()})()
|
||||
with patch("core.tools.utils.parser.httpx.get", return_value=success_response):
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle",
|
||||
return_value=["bundle"],
|
||||
) as mock_parse:
|
||||
bundles = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
'{"api": {"url": "https://x", "type": "openapi"}}'
|
||||
)
|
||||
assert bundles == ["bundle"]
|
||||
mock_parse.assert_called_once()
|
||||
success_response.close.assert_called_once()
|
||||
|
||||
|
||||
def test_auto_parse_json_yaml_failure():
|
||||
with patch("core.tools.utils.parser.json_loads", side_effect=JSONDecodeError("bad", "x", 0)):
|
||||
with patch("core.tools.utils.parser.safe_load", side_effect=YAMLError("bad yaml")):
|
||||
with pytest.raises(ToolApiSchemaError, match="Invalid api schema, schema is neither json nor yaml"):
|
||||
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(":::")
|
||||
|
||||
|
||||
def test_auto_parse_openapi_success():
|
||||
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
|
||||
return_value=["openapi-bundle"],
|
||||
):
|
||||
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
|
||||
|
||||
assert bundles == ["openapi-bundle"]
|
||||
assert schema_type == ApiProviderSchemaType.OPENAPI
|
||||
|
||||
|
||||
def test_auto_parse_openapi_then_swagger():
|
||||
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
|
||||
loaded_content = {
|
||||
"openapi": "3.0.0",
|
||||
"servers": [{"url": "https://x"}],
|
||||
"info": {"title": "x"},
|
||||
"paths": {},
|
||||
}
|
||||
converted_swagger = {
|
||||
"openapi": "3.0.0",
|
||||
"servers": [{"url": "https://x"}],
|
||||
"info": {"title": "x"},
|
||||
"paths": {},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
|
||||
side_effect=[ToolApiSchemaError("openapi error"), ["swagger-bundle"]],
|
||||
) as mock_parse_openapi:
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi",
|
||||
return_value=converted_swagger,
|
||||
) as mock_parse_swagger:
|
||||
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
|
||||
|
||||
assert bundles == ["swagger-bundle"]
|
||||
assert schema_type == ApiProviderSchemaType.SWAGGER
|
||||
mock_parse_swagger.assert_called_once_with(loaded_content, extra_info={}, warning={})
|
||||
assert mock_parse_openapi.call_count == 2
|
||||
mock_parse_openapi.assert_any_call(loaded_content, extra_info={}, warning={})
|
||||
mock_parse_openapi.assert_any_call(converted_swagger, extra_info={}, warning={})
|
||||
|
||||
|
||||
def test_auto_parse_openapi_swagger_then_plugin():
|
||||
openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}'
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle",
|
||||
side_effect=ToolApiSchemaError("openapi error"),
|
||||
):
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi",
|
||||
side_effect=ToolApiSchemaError("swagger error"),
|
||||
):
|
||||
with patch(
|
||||
"core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle",
|
||||
return_value=["plugin-bundle"],
|
||||
):
|
||||
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content)
|
||||
|
||||
assert bundles == ["plugin-bundle"]
|
||||
assert schema_type == ApiProviderSchemaType.OPENAI_PLUGIN
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.utils import system_oauth_encryption as oauth_encryption
|
||||
from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_roundtrip():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(payload)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert encrypted
|
||||
assert dict(decrypted) == payload
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_decrypt_validates_input():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
encrypter.decrypt_oauth_params(123) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
encrypter.decrypt_oauth_params("")
|
||||
|
||||
|
||||
def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
|
||||
encrypter = SystemOAuthEncrypter(secret_key="test-secret")
|
||||
|
||||
with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
|
||||
encrypter.decrypt_oauth_params("not-base64")
|
||||
|
||||
|
||||
def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
|
||||
monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
|
||||
monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")
|
||||
|
||||
first = oauth_encryption.get_system_oauth_encrypter()
|
||||
second = oauth_encryption.get_system_oauth_encrypter()
|
||||
assert first is second
|
||||
|
||||
encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
|
||||
assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}
|
||||
|
||||
|
||||
def test_create_system_oauth_encrypter_factory():
|
||||
encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
|
||||
@ -31,3 +33,91 @@ def test_ensure_no_human_input_nodes_raises_for_human_input():
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
|
||||
|
||||
def test_get_workflow_graph_variables_and_outputs():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"data": {
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "query",
|
||||
"label": "Query",
|
||||
"type": "text-input",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"data": {
|
||||
"type": "end",
|
||||
"outputs": [
|
||||
{"variable": "answer", "value_type": "string", "value_selector": ["n1", "answer"]},
|
||||
{"variable": "score", "value_type": "number", "value_selector": ["n1", "score"]},
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "end-2",
|
||||
"data": {
|
||||
"type": "end",
|
||||
"outputs": [
|
||||
{"variable": "answer", "value_type": "object", "value_selector": ["n2", "answer"]},
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
assert len(variables) == 1
|
||||
assert variables[0].variable == "query"
|
||||
assert variables[0].type == VariableEntityType.TEXT_INPUT
|
||||
|
||||
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
|
||||
assert [output.variable for output in outputs] == ["answer", "score"]
|
||||
assert outputs[0].value_type == "object"
|
||||
assert outputs[1].value_type == "number"
|
||||
|
||||
no_start = WorkflowToolConfigurationUtils.get_workflow_graph_variables({"nodes": []})
|
||||
assert no_start == []
|
||||
|
||||
|
||||
def test_check_is_synced_validation():
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="query",
|
||||
label="Query",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
configs = [
|
||||
WorkflowToolParameterConfiguration(
|
||||
name="query",
|
||||
description="desc",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
)
|
||||
]
|
||||
|
||||
WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=configs)
|
||||
|
||||
with pytest.raises(ValueError, match="parameter configuration mismatch"):
|
||||
WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=[])
|
||||
|
||||
with pytest.raises(ValueError, match="parameter configuration mismatch"):
|
||||
WorkflowToolConfigurationUtils.check_is_synced(
|
||||
variables=variables,
|
||||
tool_configurations=[
|
||||
WorkflowToolParameterConfiguration(
|
||||
name="other",
|
||||
description="desc",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
def _controller() -> WorkflowToolProviderController:
|
||||
entity = ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author="author",
|
||||
name="wf-provider",
|
||||
description=I18nObject(en_US="desc"),
|
||||
icon="icon.svg",
|
||||
label=I18nObject(en_US="WF"),
|
||||
),
|
||||
credentials_schema=[],
|
||||
)
|
||||
return WorkflowToolProviderController(entity=entity, provider_id="provider-1")
|
||||
|
||||
|
||||
def _mock_session_with_begin() -> Mock:
|
||||
session = Mock()
|
||||
begin_cm = Mock()
|
||||
begin_cm.__enter__ = Mock(return_value=None)
|
||||
begin_cm.__exit__ = Mock(return_value=False)
|
||||
session.begin.return_value = begin_cm
|
||||
return session
|
||||
|
||||
|
||||
def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
parameter_configurations=[
|
||||
SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM),
|
||||
SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM),
|
||||
],
|
||||
)
|
||||
user = SimpleNamespace(name="Alice")
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="country",
|
||||
label="Country",
|
||||
description="Country",
|
||||
type=VariableEntityType.SELECT,
|
||||
required=True,
|
||||
options=["US", "IN"],
|
||||
)
|
||||
]
|
||||
outputs = [
|
||||
SimpleNamespace(variable="json", value_type="string"),
|
||||
SimpleNamespace(variable="answer", value_type="string"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.tools.workflow_as_tool.provider.WorkflowAppConfigManager.convert_features",
|
||||
return_value=SimpleNamespace(file_upload=True),
|
||||
),
|
||||
patch(
|
||||
"core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_variables",
|
||||
return_value=variables,
|
||||
),
|
||||
patch(
|
||||
"core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_output",
|
||||
return_value=outputs,
|
||||
),
|
||||
):
|
||||
tool = controller._get_db_provider_tool(db_provider, app, session=session, user=user)
|
||||
|
||||
assert tool.entity.identity.name == "workflow_tool"
|
||||
# "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out.
|
||||
assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}}
|
||||
assert "json" not in tool.entity.output_schema["properties"]
|
||||
assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT
|
||||
assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
assert controller.provider_type == ToolProviderType.WORKFLOW
|
||||
|
||||
|
||||
def test_get_tool_returns_hit_or_none():
|
||||
controller = _controller()
|
||||
tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool")))
|
||||
controller.tools = [tool]
|
||||
|
||||
assert controller.get_tool("workflow_tool") is tool
|
||||
assert controller.get_tool("missing") is None
|
||||
|
||||
|
||||
def test_get_tools_returns_cached():
|
||||
controller = _controller()
|
||||
cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))]
|
||||
controller.tools = cached_tools # type: ignore[assignment]
|
||||
|
||||
assert controller.get_tools("tenant-1") == cached_tools
|
||||
|
||||
|
||||
def test_from_db_builds_controller():
|
||||
controller = _controller()
|
||||
|
||||
app = SimpleNamespace(id="app-1")
|
||||
user = SimpleNamespace(name="Alice")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
user_id="user-1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
parameter_configurations=[],
|
||||
)
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
fake_cm = MagicMock()
|
||||
fake_cm.__enter__.return_value = session
|
||||
fake_cm.__exit__.return_value = False
|
||||
fake_session_factory = Mock()
|
||||
fake_session_factory.create_session.return_value = fake_cm
|
||||
|
||||
with patch("core.tools.workflow_as_tool.provider.session_factory", fake_session_factory):
|
||||
with patch.object(
|
||||
WorkflowToolProviderController,
|
||||
"_get_db_provider_tool",
|
||||
return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))),
|
||||
):
|
||||
built = WorkflowToolProviderController.from_db(db_provider)
|
||||
assert isinstance(built, WorkflowToolProviderController)
|
||||
assert built.tools
|
||||
|
||||
|
||||
def test_get_tools_returns_empty_when_provider_missing():
|
||||
controller = _controller()
|
||||
controller.tools = None # type: ignore[assignment]
|
||||
|
||||
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
|
||||
assert controller.get_tools("tenant-1") == []
|
||||
|
||||
|
||||
def test_get_tools_raises_when_app_missing():
|
||||
controller = _controller()
|
||||
controller.tools = None # type: ignore[assignment]
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
user_id="user-1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
parameter_configurations=[],
|
||||
)
|
||||
|
||||
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.get.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
controller.get_tools("tenant-1")
|
||||
@ -1,20 +1,85 @@
|
||||
"""Unit tests for workflow-as-tool behavior.
|
||||
|
||||
StubSession/StubScalars emulate SQLAlchemy session/scalars with minimal methods
|
||||
(`scalar`, `scalars`, `expunge`, `commit`, `refresh`, context manager) to keep
|
||||
database access mocked and predictable in tests.
|
||||
"""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||
the `data` element.
|
||||
"""
|
||||
class StubScalars:
|
||||
"""Minimal stub for SQLAlchemy scalar results."""
|
||||
|
||||
_value: Any
|
||||
|
||||
def __init__(self, value: Any) -> None:
|
||||
self._value = value
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._value
|
||||
|
||||
|
||||
class StubSession:
|
||||
"""Minimal stub for session_factory-created sessions."""
|
||||
|
||||
scalar_results: list[Any]
|
||||
scalars_results: list[Any]
|
||||
expunge_calls: list[object]
|
||||
|
||||
def __init__(self, *, scalar_results: list[Any] | None = None, scalars_results: list[Any] | None = None) -> None:
|
||||
self.scalar_results = list(scalar_results or [])
|
||||
self.scalars_results = list(scalars_results or [])
|
||||
self.expunge_calls: list[object] = []
|
||||
|
||||
def scalar(self, _stmt: Any) -> Any:
|
||||
return self.scalar_results.pop(0)
|
||||
|
||||
def scalars(self, _stmt: Any) -> StubScalars:
|
||||
return StubScalars(self.scalars_results.pop(0))
|
||||
|
||||
def expunge(self, value: Any) -> None:
|
||||
self.expunge_calls.append(value)
|
||||
|
||||
def begin(self) -> "StubSession":
|
||||
return self
|
||||
|
||||
def commit(self) -> None:
|
||||
pass
|
||||
|
||||
def refresh(self, _value: Any) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "StubSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _build_tool() -> WorkflowTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
@ -22,9 +87,9 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
return WorkflowTool(
|
||||
workflow_app_id="app-1",
|
||||
workflow_as_tool_id="wf-tool-1",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
@ -32,13 +97,19 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||
the `data` element.
|
||||
"""
|
||||
tool = _build_tool()
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
# Mock user resolution to avoid database access
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
@ -56,28 +127,12 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
|
||||
|
||||
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
"""Ensure pause_state_config is passed as None."""
|
||||
tool = _build_tool()
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
@ -94,22 +149,7 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke
|
||||
|
||||
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should generate variable messages when there are outputs"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
tool = _build_tool()
|
||||
|
||||
# Mock workflow outputs
|
||||
mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}}
|
||||
@ -119,8 +159,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
# Mock user resolution to avoid database access
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
@ -134,10 +172,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
|
||||
# Execute tool invocation
|
||||
messages = list(tool.invoke("test_user", {}))
|
||||
|
||||
# Verify generated messages
|
||||
# Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages
|
||||
assert len(messages) == 5
|
||||
|
||||
# Verify variable messages
|
||||
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
|
||||
assert len(variable_messages) == 3
|
||||
@ -151,7 +185,7 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
|
||||
# Verify text message
|
||||
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
|
||||
assert len(text_messages) == 1
|
||||
assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text
|
||||
assert json.loads(text_messages[0].message.text) == mock_outputs
|
||||
|
||||
# Verify JSON message
|
||||
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
|
||||
@ -161,30 +195,13 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch
|
||||
|
||||
def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should handle empty outputs correctly"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
tool = _build_tool()
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
# Mock user resolution to avoid database access
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
@ -217,61 +234,32 @@ def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPat
|
||||
assert json_messages[0].message.json_object == {}
|
||||
|
||||
|
||||
def test_create_variable_message():
|
||||
"""Test the functionality of creating variable messages"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# Test different types of variable values
|
||||
test_cases = [
|
||||
@pytest.mark.parametrize(
|
||||
("var_name", "var_value"),
|
||||
[
|
||||
("string_var", "test string"),
|
||||
("int_var", 42),
|
||||
("float_var", 3.14),
|
||||
("bool_var", True),
|
||||
("list_var", [1, 2, 3]),
|
||||
("dict_var", {"key": "value"}),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_create_variable_message(var_name, var_value):
|
||||
"""Create variable messages for multiple value types."""
|
||||
tool = _build_tool()
|
||||
|
||||
for var_name, var_value in test_cases:
|
||||
message = tool.create_variable_message(var_name, var_value)
|
||||
message = tool.create_variable_message(var_name, var_value)
|
||||
|
||||
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
|
||||
assert message.message.variable_name == var_name
|
||||
assert message.message.variable_value == var_value
|
||||
assert message.message.stream is False
|
||||
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
|
||||
assert message.message.variable_name == var_name
|
||||
assert message.message.variable_value == var_value
|
||||
assert message.message.stream is False
|
||||
|
||||
|
||||
def test_create_file_message_should_include_file_marker():
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
"""Ensure file message includes marker and meta payload."""
|
||||
tool = _build_tool()
|
||||
|
||||
file_obj = object()
|
||||
message = tool.create_file_message(file_obj) # type: ignore[arg-type]
|
||||
@ -284,103 +272,247 @@ def test_create_file_message_should_include_file_marker():
|
||||
def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure worker context can resolve EndUser when Account is missing."""
|
||||
|
||||
class StubSession:
|
||||
def __init__(self, results: list):
|
||||
self.results = results
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return self.results.pop(0)
|
||||
|
||||
# SQLAlchemy Session APIs used by code under test
|
||||
def expunge(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# support `with session_factory.create_session() as session:`
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
tenant = SimpleNamespace(id="tenant_id")
|
||||
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
|
||||
|
||||
# Monkeypatch session factory to return our stub session
|
||||
stub_session = StubSession(scalar_results=[tenant, None, end_user])
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: StubSession([tenant, None, end_user]),
|
||||
lambda: stub_session,
|
||||
)
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
tool = _build_tool()
|
||||
tool.runtime.invoke_from = InvokeFrom.SERVICE_API
|
||||
tool.runtime.tenant_id = "tenant_id"
|
||||
|
||||
resolved_user = tool._resolve_user_from_database(user_id=end_user.id)
|
||||
|
||||
assert resolved_user is end_user
|
||||
assert stub_session.expunge_calls == [end_user]
|
||||
|
||||
|
||||
def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Return None if tenant cannot be found in worker context."""
|
||||
|
||||
class StubSession:
|
||||
def __init__(self, results: list):
|
||||
self.results = results
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return self.results.pop(0)
|
||||
|
||||
def expunge(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
# Monkeypatch session factory to return our stub session with no tenant
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: StubSession([None]),
|
||||
lambda: StubSession(scalar_results=[None]),
|
||||
)
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
tool = _build_tool()
|
||||
tool.runtime.invoke_from = InvokeFrom.SERVICE_API
|
||||
tool.runtime.tenant_id = "missing_tenant"
|
||||
|
||||
resolved_user = tool._resolve_user_from_database(user_id="any")
|
||||
|
||||
assert resolved_user is None
|
||||
|
||||
|
||||
def test_workflow_tool_provider_type_and_fork_runtime():
|
||||
"""Verify provider type and forked runtime behavior."""
|
||||
tool = _build_tool()
|
||||
assert tool.tool_provider_type() == ToolProviderType.WORKFLOW
|
||||
assert tool.latest_usage.total_tokens == 0
|
||||
|
||||
forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2", invoke_from=InvokeFrom.DEBUGGER))
|
||||
assert isinstance(forked, WorkflowTool)
|
||||
assert forked.workflow_app_id == tool.workflow_app_id
|
||||
assert forked.runtime.tenant_id == "tenant-2"
|
||||
|
||||
|
||||
def test_derive_usage_from_top_level_usage_key():
|
||||
"""Derive usage from top-level usage dict."""
|
||||
usage = WorkflowTool._derive_usage_from_result({"usage": {"total_tokens": 12, "total_price": "0.2"}})
|
||||
assert usage.total_tokens == 12
|
||||
|
||||
|
||||
def test_derive_usage_from_metadata_usage():
|
||||
"""Derive usage from metadata usage dict."""
|
||||
metadata_usage = WorkflowTool._derive_usage_from_result({"metadata": {"usage": {"total_tokens": 7}}})
|
||||
assert metadata_usage.total_tokens == 7
|
||||
|
||||
|
||||
def test_derive_usage_from_totals():
|
||||
"""Derive usage from top-level totals fields."""
|
||||
totals_usage = WorkflowTool._derive_usage_from_result(
|
||||
{"total_tokens": "9", "total_price": "1.3", "currency": "USD"}
|
||||
)
|
||||
assert totals_usage.total_tokens == 9
|
||||
assert str(totals_usage.total_price) == "1.3"
|
||||
|
||||
|
||||
def test_derive_usage_from_empty():
|
||||
"""Default usage values when result is empty."""
|
||||
empty_usage = WorkflowTool._derive_usage_from_result({})
|
||||
assert empty_usage.total_tokens == 0
|
||||
|
||||
|
||||
def test_extract_usage_from_nested():
|
||||
"""Extract nested usage dict from result payloads."""
|
||||
nested = WorkflowTool._extract_usage_dict({"nested": [{"data": {"usage": {"total_tokens": 3}}}]})
|
||||
assert nested == {"total_tokens": 3}
|
||||
|
||||
|
||||
def test_invoke_raises_when_user_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Raise ToolInvokeError when user resolution fails."""
|
||||
tool = _build_tool()
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: None)
|
||||
|
||||
with pytest.raises(ToolInvokeError, match="User not found"):
|
||||
list(tool.invoke("missing", {}))
|
||||
|
||||
|
||||
def test_resolve_user_from_database_returns_account(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Resolve Account and set tenant in worker context."""
|
||||
tenant = SimpleNamespace(id="tenant_id")
|
||||
account = SimpleNamespace(id="account_id", current_tenant=None)
|
||||
session = StubSession(scalar_results=[tenant, account])
|
||||
|
||||
monkeypatch.setattr("core.tools.workflow_as_tool.tool.session_factory.create_session", lambda: session)
|
||||
tool = _build_tool()
|
||||
tool.runtime.tenant_id = "tenant_id"
|
||||
|
||||
resolved = tool._resolve_user_from_database(user_id="account_id")
|
||||
assert resolved is account
|
||||
assert account.current_tenant is tenant
|
||||
assert session.expunge_calls == [account]
|
||||
|
||||
|
||||
def test_get_workflow_and_get_app_db_branches(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Cover workflow/app retrieval branches and error cases."""
|
||||
tool = _build_tool()
|
||||
latest_workflow = SimpleNamespace(id="wf-latest")
|
||||
specific_workflow = SimpleNamespace(id="wf-v1")
|
||||
app = SimpleNamespace(id="app-1")
|
||||
sessions = iter(
|
||||
[
|
||||
StubSession(scalar_results=[], scalars_results=[latest_workflow]),
|
||||
StubSession(scalar_results=[specific_workflow], scalars_results=[]),
|
||||
StubSession(scalar_results=[app], scalars_results=[]),
|
||||
]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: next(sessions),
|
||||
)
|
||||
|
||||
assert tool._get_workflow("app-1", "") is latest_workflow
|
||||
assert tool._get_workflow("app-1", "1") is specific_workflow
|
||||
assert tool._get_app("app-1") is app
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: StubSession(scalar_results=[None, None], scalars_results=[None]),
|
||||
)
|
||||
with pytest.raises(ValueError, match="workflow not found"):
|
||||
tool._get_workflow("app-1", "1")
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
tool._get_app("app-1")
|
||||
|
||||
|
||||
def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool:
|
||||
"""Build a WorkflowTool and stub merged runtime parameters for files/query."""
|
||||
tool = _build_tool()
|
||||
files_param = ToolParameter.get_simple_instance(
|
||||
name="files",
|
||||
llm_description="files",
|
||||
typ=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
required=False,
|
||||
)
|
||||
files_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
text_param = ToolParameter.get_simple_instance(
|
||||
name="query",
|
||||
llm_description="query",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
)
|
||||
text_param.form = ToolParameter.ToolParameterForm.FORM
|
||||
|
||||
monkeypatch.setattr(tool, "get_merged_runtime_parameters", lambda: [files_param, text_param])
|
||||
return tool
|
||||
|
||||
|
||||
def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Transform args into parameters and files payloads."""
|
||||
tool = _setup_transform_args_tool(monkeypatch)
|
||||
|
||||
params, files = tool._transform_args(
|
||||
{
|
||||
"query": "hello",
|
||||
"files": [
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"type": "image",
|
||||
"transfer_method": "tool_file",
|
||||
"related_id": "tool-1",
|
||||
"extension": ".png",
|
||||
},
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"type": "document",
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "upload-1",
|
||||
},
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"type": "document",
|
||||
"transfer_method": "remote_url",
|
||||
"remote_url": "https://example.com/a.pdf",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
assert params == {"query": "hello"}
|
||||
assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files)
|
||||
assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files)
|
||||
assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files)
|
||||
|
||||
|
||||
def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ignore invalid file entries while keeping params."""
|
||||
tool = _setup_transform_args_tool(monkeypatch)
|
||||
invalid_params, invalid_files = tool._transform_args({"query": "hello", "files": [{"invalid": True}]})
|
||||
assert invalid_params == {"query": "hello"}
|
||||
assert invalid_files == []
|
||||
|
||||
|
||||
def test_extract_files():
|
||||
"""Extract file outputs into result and file list."""
|
||||
tool = _build_tool()
|
||||
built_files = [
|
||||
SimpleNamespace(id="file-1"),
|
||||
SimpleNamespace(id="file-2"),
|
||||
]
|
||||
with patch("core.tools.workflow_as_tool.tool.build_from_mapping", side_effect=built_files):
|
||||
outputs = {
|
||||
"attachments": [
|
||||
{
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"transfer_method": "tool_file",
|
||||
"related_id": "r1",
|
||||
}
|
||||
],
|
||||
"single_file": {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "r2",
|
||||
},
|
||||
"text": "ok",
|
||||
}
|
||||
result, extracted_files = tool._extract_files(outputs)
|
||||
|
||||
assert result["text"] == "ok"
|
||||
assert len(extracted_files) == 2
|
||||
|
||||
|
||||
def test_update_file_mapping():
|
||||
"""Map tool/local file transfer methods into output shape."""
|
||||
tool = _build_tool()
|
||||
tool_file = tool._update_file_mapping({"transfer_method": "tool_file", "related_id": "tool-1"})
|
||||
assert tool_file["tool_file_id"] == "tool-1"
|
||||
local_file = tool._update_file_mapping({"transfer_method": "local_file", "related_id": "upload-1"})
|
||||
assert local_file["upload_file_id"] == "upload-1"
|
||||
|
||||
0
api/tests/unit_tests/tools/__init__.py
Normal file
0
api/tests/unit_tests/tools/__init__.py
Normal file
Loading…
Reference in New Issue
Block a user