mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: unit test cases for core.variables, core.plugin, core.prompt module (#32637)
This commit is contained in:
parent
135b3a15a6
commit
07e19c0748
@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
|
||||
0
api/tests/unit_tests/core/plugin/impl/__init__.py
Normal file
0
api/tests/unit_tests/core/plugin/impl/__init__.py
Normal file
91
api/tests/unit_tests/core/plugin/impl/test_agent_client.py
Normal file
91
api/tests/unit_tests/core/plugin/impl/test_agent_client.py
Normal file
@ -0,0 +1,91 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.agent import PluginAgentClient
|
||||
|
||||
|
||||
def _agent_provider(name: str = "agent") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
plugin_id="org/plugin",
|
||||
declaration=SimpleNamespace(
|
||||
identity=SimpleNamespace(name=name),
|
||||
strategies=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPluginAgentClient:
|
||||
def test_fetch_agent_strategy_providers(self, mocker):
|
||||
client = PluginAgentClient()
|
||||
provider = _agent_provider("remote")
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"declaration": {
|
||||
"identity": {"name": "remote"},
|
||||
"strategies": [{"identity": {"provider": "old"}}],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"][0]["declaration"]["strategies"][0]["identity"]["provider"] == "remote"
|
||||
return [provider]
|
||||
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = client.fetch_agent_strategy_providers("tenant-1")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result[0].declaration.identity.name == "org/plugin/remote"
|
||||
assert result[0].declaration.strategies[0].identity.provider == "org/plugin/remote"
|
||||
|
||||
def test_fetch_agent_strategy_provider(self, mocker):
|
||||
client = PluginAgentClient()
|
||||
provider = _agent_provider("provider")
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
assert transformer({"data": None}) == {"data": None}
|
||||
payload = {"data": {"declaration": {"strategies": [{"identity": {"provider": "old"}}]}}}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"]["declaration"]["strategies"][0]["identity"]["provider"] == "provider"
|
||||
return provider
|
||||
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = client.fetch_agent_strategy_provider("tenant-1", "org/plugin/provider")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result.declaration.identity.name == "org/plugin/provider"
|
||||
assert result.declaration.strategies[0].identity.provider == "org/plugin/provider"
|
||||
|
||||
def test_invoke_merges_chunks_and_passes_context(self, mocker):
|
||||
client = PluginAgentClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client, "_request_with_plugin_daemon_response_stream", return_value=iter(["raw"])
|
||||
)
|
||||
merge_mock = mocker.patch("core.plugin.impl.agent.merge_blob_chunks", return_value=["merged"])
|
||||
context = PluginInvokeContext()
|
||||
|
||||
result = client.invoke(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
agent_provider="org/plugin/provider",
|
||||
agent_strategy="router",
|
||||
agent_params={"k": "v"},
|
||||
conversation_id="conv-1",
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result == ["merged"]
|
||||
assert merge_mock.call_count == 1
|
||||
payload = stream_mock.call_args.kwargs["data"]
|
||||
assert payload["data"]["agent_strategy_provider"] == "provider"
|
||||
assert payload["context"] == context.model_dump()
|
||||
assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin"
|
||||
45
api/tests/unit_tests/core/plugin/impl/test_asset_manager.py
Normal file
45
api/tests/unit_tests/core/plugin/impl/test_asset_manager.py
Normal file
@ -0,0 +1,45 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
|
||||
|
||||
class TestPluginAssetManager:
|
||||
def test_fetch_asset_success(self, mocker):
|
||||
manager = PluginAssetManager()
|
||||
response = MagicMock(status_code=200, content=b"asset-bytes")
|
||||
request_mock = mocker.patch.object(manager, "_request", return_value=response)
|
||||
|
||||
result = manager.fetch_asset("tenant-1", "asset-1")
|
||||
|
||||
assert result == b"asset-bytes"
|
||||
request_mock.assert_called_once_with(method="GET", path="plugin/tenant-1/asset/asset-1")
|
||||
|
||||
def test_fetch_asset_not_found_raises(self, mocker):
|
||||
manager = PluginAssetManager()
|
||||
mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b""))
|
||||
|
||||
with pytest.raises(ValueError, match="can not found asset asset-1"):
|
||||
manager.fetch_asset("tenant-1", "asset-1")
|
||||
|
||||
def test_extract_asset_success(self, mocker):
|
||||
manager = PluginAssetManager()
|
||||
response = MagicMock(status_code=200, content=b"file-content")
|
||||
request_mock = mocker.patch.object(manager, "_request", return_value=response)
|
||||
|
||||
result = manager.extract_asset("tenant-1", "org/plugin:1", "README.md")
|
||||
|
||||
assert result == b"file-content"
|
||||
request_mock.assert_called_once_with(
|
||||
method="GET",
|
||||
path="plugin/tenant-1/extract-asset/",
|
||||
params={"plugin_unique_identifier": "org/plugin:1", "file_path": "README.md"},
|
||||
)
|
||||
|
||||
def test_extract_asset_not_found_raises(self, mocker):
|
||||
manager = PluginAssetManager()
|
||||
mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b""))
|
||||
|
||||
with pytest.raises(ValueError, match="can not found asset org/plugin:1, 404"):
|
||||
manager.extract_asset("tenant-1", "org/plugin:1", "README.md")
|
||||
137
api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py
Normal file
137
api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py
Normal file
@ -0,0 +1,137 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.endpoint.exc import EndpointSetupFailedError
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.trigger.errors import (
|
||||
EventIgnoreError,
|
||||
TriggerInvokeError,
|
||||
TriggerPluginInvokeError,
|
||||
TriggerProviderCredentialValidationError,
|
||||
)
|
||||
|
||||
|
||||
class _ResponseStub:
|
||||
def __init__(self, payload):
|
||||
self._payload = payload
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _StreamContext:
|
||||
def __init__(self, lines):
|
||||
self._lines = lines
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def iter_lines(self):
|
||||
return self._lines
|
||||
|
||||
|
||||
class TestBasePluginClientImpl:
|
||||
def test_inject_trace_headers(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch("core.plugin.impl.base.dify_config.ENABLE_OTEL", True)
|
||||
trace_header = "00-abc-xyz-01"
|
||||
mocker.patch("core.helper.trace_id_helper.generate_traceparent_header", return_value=trace_header)
|
||||
|
||||
headers = {}
|
||||
client._inject_trace_headers(headers)
|
||||
|
||||
assert headers["traceparent"] == trace_header
|
||||
|
||||
headers_with_existing = {"TraceParent": "exists"}
|
||||
client._inject_trace_headers(headers_with_existing)
|
||||
assert headers_with_existing["TraceParent"] == "exists"
|
||||
|
||||
def test_stream_request_handles_data_lines_and_dict_payload(self, mocker):
|
||||
client = BasePluginClient()
|
||||
stream_mock = mocker.patch(
|
||||
"core.plugin.impl.base.httpx.stream",
|
||||
return_value=_StreamContext([b"", b"data: hello", "world"]),
|
||||
)
|
||||
|
||||
result = list(client._stream_request("POST", "plugin/tenant/stream", data={"k": "v"}))
|
||||
|
||||
assert result == ["hello", "world"]
|
||||
assert stream_mock.call_args.kwargs["data"] == {"k": "v"}
|
||||
|
||||
def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_request", side_effect=RuntimeError("boom"))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to request plugin daemon"):
|
||||
client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool)
|
||||
|
||||
def test_request_with_plugin_daemon_response_applies_transformer(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_request", return_value=_ResponseStub({"code": 0, "message": "", "data": True}))
|
||||
|
||||
transformed = {}
|
||||
|
||||
def transformer(payload):
|
||||
transformed.update(payload)
|
||||
return payload
|
||||
|
||||
result = client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool, transformer=transformer)
|
||||
|
||||
assert result is True
|
||||
assert transformed == {"code": 0, "message": "", "data": True}
|
||||
|
||||
def test_request_with_plugin_daemon_response_stream_malformed_json_error(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_stream_request", return_value=iter(['{"error":"bad-line"}']))
|
||||
|
||||
with pytest.raises(ValueError, match="bad-line"):
|
||||
list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))
|
||||
|
||||
def test_request_with_plugin_daemon_response_stream_plugin_daemon_inner_error(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(
|
||||
client, "_stream_request", return_value=iter(['{"code":-500,"message":"not-json","data":null}'])
|
||||
)
|
||||
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))
|
||||
assert exc_info.value.message == "not-json"
|
||||
|
||||
def test_request_with_plugin_daemon_response_stream_plugin_daemon_error(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":-1,"message":"err","data":null}']))
|
||||
|
||||
with pytest.raises(ValueError, match="plugin daemon: err, code: -1"):
|
||||
list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))
|
||||
|
||||
def test_request_with_plugin_daemon_response_stream_empty_data_error(self, mocker):
|
||||
client = BasePluginClient()
|
||||
mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":0,"message":"","data":null}']))
|
||||
|
||||
with pytest.raises(ValueError, match="got empty data"):
|
||||
list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error_type", "expected"),
|
||||
[
|
||||
(EndpointSetupFailedError.__name__, EndpointSetupFailedError),
|
||||
(TriggerProviderCredentialValidationError.__name__, TriggerProviderCredentialValidationError),
|
||||
(TriggerPluginInvokeError.__name__, TriggerPluginInvokeError),
|
||||
(TriggerInvokeError.__name__, TriggerInvokeError),
|
||||
(EventIgnoreError.__name__, EventIgnoreError),
|
||||
],
|
||||
)
|
||||
def test_handle_plugin_daemon_error_trigger_branches(self, error_type, expected):
|
||||
client = BasePluginClient()
|
||||
message = json.dumps({"error_type": error_type, "message": "m"})
|
||||
|
||||
with pytest.raises(expected):
|
||||
client._handle_plugin_daemon_error("PluginInvokeError", message)
|
||||
234
api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py
Normal file
234
api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py
Normal file
@ -0,0 +1,234 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
def _datasource_provider(name: str = "provider") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
plugin_id="org/plugin",
|
||||
declaration=SimpleNamespace(
|
||||
identity=SimpleNamespace(name=name),
|
||||
datasources=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPluginDatasourceManager:
|
||||
def test_fetch_datasource_providers(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
provider = _datasource_provider("remote")
|
||||
repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider")
|
||||
mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"declaration": {
|
||||
"identity": {"name": "remote"},
|
||||
"datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"][0]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True}
|
||||
return [provider]
|
||||
|
||||
request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = manager.fetch_datasource_providers("tenant-1")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert len(result) == 2
|
||||
assert result[0].plugin_id == "langgenius/file"
|
||||
assert result[1].declaration.identity.name == "org/plugin/remote"
|
||||
assert result[1].declaration.datasources[0].identity.provider == "org/plugin/remote"
|
||||
repack.assert_called_once_with(tenant_id="tenant-1", provider=provider)
|
||||
|
||||
def test_fetch_installed_datasource_providers(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
provider = _datasource_provider("remote")
|
||||
repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider")
|
||||
mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"declaration": {
|
||||
"identity": {"name": "remote"},
|
||||
"datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
transformer(payload)
|
||||
return [provider]
|
||||
|
||||
request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = manager.fetch_installed_datasource_providers("tenant-1")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert len(result) == 1
|
||||
assert result[0].declaration.identity.name == "org/plugin/remote"
|
||||
assert result[0].declaration.datasources[0].identity.provider == "org/plugin/remote"
|
||||
repack.assert_called_once_with(tenant_id="tenant-1", provider=provider)
|
||||
|
||||
def test_fetch_datasource_provider_local_and_remote(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
local = manager.fetch_datasource_provider("tenant-1", "langgenius/file/file")
|
||||
assert local.plugin_id == "langgenius/file"
|
||||
|
||||
remote = _datasource_provider("provider")
|
||||
mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": {
|
||||
"declaration": {
|
||||
"datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]
|
||||
}
|
||||
}
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True}
|
||||
return remote
|
||||
|
||||
request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = manager.fetch_datasource_provider("tenant-1", "org/plugin/provider")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result.declaration.identity.name == "org/plugin/provider"
|
||||
assert result.declaration.datasources[0].identity.provider == "org/plugin/provider"
|
||||
|
||||
def test_get_website_crawl_streaming(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter(["crawl"])
|
||||
|
||||
assert list(
|
||||
manager.get_website_crawl(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin/provider",
|
||||
"crawl",
|
||||
{"k": "v"},
|
||||
{"url": "https://example.com"},
|
||||
"website",
|
||||
)
|
||||
) == ["crawl"]
|
||||
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_get_online_document_pages_streaming(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter(["pages"])
|
||||
|
||||
assert list(
|
||||
manager.get_online_document_pages(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin/provider",
|
||||
"docs",
|
||||
{"k": "v"},
|
||||
{"workspace": "w1"},
|
||||
"online_document",
|
||||
)
|
||||
) == ["pages"]
|
||||
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_get_online_document_page_content_streaming(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter(["content"])
|
||||
|
||||
assert list(
|
||||
manager.get_online_document_page_content(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin/provider",
|
||||
"docs",
|
||||
{"k": "v"},
|
||||
GetOnlineDocumentPageContentRequest(workspace_id="w", page_id="p", type="doc"),
|
||||
"online_document",
|
||||
)
|
||||
) == ["content"]
|
||||
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_online_drive_browse_files_streaming(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter(["browse"])
|
||||
|
||||
assert list(
|
||||
manager.online_drive_browse_files(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin/provider",
|
||||
"drive",
|
||||
{"k": "v"},
|
||||
OnlineDriveBrowseFilesRequest(prefix="/"),
|
||||
"online_drive",
|
||||
)
|
||||
) == ["browse"]
|
||||
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_online_drive_download_file_streaming(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter(["download"])
|
||||
|
||||
assert list(
|
||||
manager.online_drive_download_file(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin/provider",
|
||||
"drive",
|
||||
{"k": "v"},
|
||||
OnlineDriveDownloadFileRequest(id="file-1"),
|
||||
"online_drive",
|
||||
)
|
||||
) == ["download"]
|
||||
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_validate_provider_credentials_returns_true_when_stream_yields_result(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter([SimpleNamespace(result=True)])
|
||||
|
||||
assert manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is True
|
||||
|
||||
def test_validate_provider_credentials_returns_false_when_stream_empty(self, mocker):
|
||||
manager = PluginDatasourceManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
stream_mock.return_value = iter([])
|
||||
|
||||
assert (
|
||||
manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is False
|
||||
)
|
||||
|
||||
def test_local_file_provider_template(self):
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
payload = manager._get_local_file_datasource_provider()
|
||||
|
||||
assert payload["plugin_id"] == "langgenius/file"
|
||||
assert payload["provider"] == "file"
|
||||
assert payload["declaration"]["provider_type"] == "local_file"
|
||||
@ -0,0 +1,21 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
|
||||
|
||||
class TestPluginDebuggingClient:
|
||||
def test_get_debugging_key(self, mocker):
|
||||
client = PluginDebuggingClient()
|
||||
request_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response",
|
||||
return_value=SimpleNamespace(key="debug-key"),
|
||||
)
|
||||
|
||||
result = client.get_debugging_key("tenant-1")
|
||||
|
||||
assert result == "debug-key"
|
||||
request_mock.assert_called_once()
|
||||
args = request_mock.call_args.args
|
||||
assert args[0] == "POST"
|
||||
assert args[1] == "plugin/tenant-1/debugging/key"
|
||||
@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.endpoint import PluginEndpointClient
|
||||
from core.plugin.impl.exc import PluginDaemonInternalServerError
|
||||
|
||||
|
||||
class TestPluginEndpointClientImpl:
|
||||
def test_create_endpoint(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
|
||||
|
||||
result = client.create_endpoint("tenant-1", "user-1", "org/plugin:1", "endpoint-a", {"k": "v"})
|
||||
|
||||
assert result is True
|
||||
assert request_mock.call_count == 1
|
||||
args = request_mock.call_args.args
|
||||
kwargs = request_mock.call_args.kwargs
|
||||
assert args[:3] == ("POST", "plugin/tenant-1/endpoint/setup", bool)
|
||||
assert kwargs["data"]["plugin_unique_identifier"] == "org/plugin:1"
|
||||
|
||||
def test_list_endpoints(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"])
|
||||
|
||||
result = client.list_endpoints("tenant-1", "user-1", 2, 20)
|
||||
|
||||
assert result == ["endpoint"]
|
||||
assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list"
|
||||
assert request_mock.call_args.kwargs["params"] == {"page": 2, "page_size": 20}
|
||||
|
||||
def test_list_endpoints_for_single_plugin(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"])
|
||||
|
||||
result = client.list_endpoints_for_single_plugin("tenant-1", "user-1", "org/plugin", 1, 10)
|
||||
|
||||
assert result == ["endpoint"]
|
||||
assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list/plugin"
|
||||
assert request_mock.call_args.kwargs["params"] == {"plugin_id": "org/plugin", "page": 1, "page_size": 10}
|
||||
|
||||
def test_update_endpoint(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
|
||||
|
||||
result = client.update_endpoint("tenant-1", "user-1", "endpoint-1", "renamed", {"x": 1})
|
||||
|
||||
assert result is True
|
||||
assert request_mock.call_args.args[:3] == ("POST", "plugin/tenant-1/endpoint/update", bool)
|
||||
|
||||
def test_enable_and_disable_endpoint(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
|
||||
|
||||
assert client.enable_endpoint("tenant-1", "user-1", "endpoint-1") is True
|
||||
assert client.disable_endpoint("tenant-1", "user-1", "endpoint-1") is True
|
||||
|
||||
calls = request_mock.call_args_list
|
||||
assert calls[0].args[1] == "plugin/tenant-1/endpoint/enable"
|
||||
assert calls[1].args[1] == "plugin/tenant-1/endpoint/disable"
|
||||
|
||||
def test_delete_endpoint_idempotent_and_re_raise(self, mocker):
|
||||
client = PluginEndpointClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response")
|
||||
|
||||
request_mock.side_effect = PluginDaemonInternalServerError("record not found")
|
||||
assert client.delete_endpoint("tenant-1", "user-1", "endpoint-1") is True
|
||||
|
||||
request_mock.side_effect = PluginDaemonInternalServerError("permission denied")
|
||||
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
|
||||
client.delete_endpoint("tenant-1", "user-1", "endpoint-1")
|
||||
assert "permission denied" in exc_info.value.description
|
||||
41
api/tests/unit_tests/core/plugin/impl/test_exc_impl.py
Normal file
41
api/tests/unit_tests/core/plugin/impl/test_exc_impl.py
Normal file
@ -0,0 +1,41 @@
|
||||
import json
|
||||
|
||||
from core.plugin.impl import exc as exc_module
|
||||
from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError
|
||||
|
||||
|
||||
class TestPluginImplExceptions:
|
||||
def test_plugin_daemon_error_str_contains_request_id(self, mocker):
|
||||
mocker.patch("core.plugin.impl.exc.get_request_id", return_value="req-123")
|
||||
error = PluginDaemonError("bad")
|
||||
|
||||
assert str(error) == "req_id: req-123 PluginDaemonError: bad"
|
||||
|
||||
def test_plugin_invoke_error_with_json_payload(self):
|
||||
err = PluginInvokeError(json.dumps({"error_type": "RateLimit", "message": "too many"}))
|
||||
|
||||
assert err.get_error_type() == "RateLimit"
|
||||
assert err.get_error_message() == "too many"
|
||||
friendly = err.to_user_friendly_error("test-plugin")
|
||||
assert "test-plugin" in friendly
|
||||
assert "RateLimit" in friendly
|
||||
assert "too many" in friendly
|
||||
|
||||
def test_plugin_invoke_error_invalid_json_and_fallback(self, mocker):
|
||||
err = PluginInvokeError("plain text")
|
||||
|
||||
assert err._get_error_object() == {}
|
||||
assert err.get_error_type() == "unknown"
|
||||
assert err.get_error_message() == "unknown"
|
||||
|
||||
mocker.patch.object(PluginInvokeError, "_get_error_object", side_effect=RuntimeError("boom"))
|
||||
err2 = PluginInvokeError("plain text")
|
||||
assert err2.get_error_message() == "plain text"
|
||||
|
||||
def test_plugin_invoke_error_get_error_object_handles_adapter_exception(self, mocker):
|
||||
adapter = mocker.patch.object(exc_module, "TypeAdapter")
|
||||
adapter.return_value.validate_json.side_effect = RuntimeError("invalid")
|
||||
|
||||
err = PluginInvokeError("not-json")
|
||||
|
||||
assert err._get_error_object() == {}
|
||||
490
api/tests/unit_tests/core/plugin/impl/test_model_client.py
Normal file
490
api/tests/unit_tests/core/plugin/impl/test_model_client.py
Normal file
@ -0,0 +1,490 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class TestPluginModelClient:
|
||||
def test_fetch_model_providers(self, mocker):
|
||||
client = PluginModelClient()
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["provider-a"])
|
||||
|
||||
result = client.fetch_model_providers("tenant-1")
|
||||
|
||||
assert result == ["provider-a"]
|
||||
assert request_mock.call_args.args[:2] == (
|
||||
"GET",
|
||||
"plugin/tenant-1/management/models",
|
||||
)
|
||||
assert request_mock.call_args.kwargs["params"] == {"page": 1, "page_size": 256}
|
||||
|
||||
def test_get_model_schema(self, mocker):
|
||||
client = PluginModelClient()
|
||||
schema = SimpleNamespace(name="schema")
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(model_schema=schema)]),
|
||||
)
|
||||
|
||||
result = client.get_model_schema(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model_type="llm",
|
||||
model="gpt-test",
|
||||
credentials={"api_key": "key"},
|
||||
)
|
||||
|
||||
assert result is schema
|
||||
assert stream_mock.call_args.args[:2] == ("POST", "plugin/tenant-1/dispatch/model/schema")
|
||||
|
||||
def test_get_model_schema_empty_stream_returns_none(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
result = client.get_model_schema("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {})
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_provider_credentials(self, mocker):
|
||||
client = PluginModelClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result=True, credentials={"api_key": "new"})]),
|
||||
)
|
||||
credentials = {"api_key": "old"}
|
||||
|
||||
result = client.validate_provider_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert credentials["api_key"] == "new"
|
||||
assert stream_mock.call_args.args[:2] == (
|
||||
"POST",
|
||||
"plugin/tenant-1/dispatch/model/validate_provider_credentials",
|
||||
)
|
||||
|
||||
def test_validate_provider_credentials_without_dict_update(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result=False, credentials="not-a-dict")]),
|
||||
)
|
||||
credentials = {"api_key": "same"}
|
||||
|
||||
result = client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", credentials)
|
||||
|
||||
assert result is False
|
||||
assert credentials == {"api_key": "same"}
|
||||
|
||||
def test_validate_provider_credentials_empty_returns_false(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", {}) is False
|
||||
|
||||
def test_validate_model_credentials(self, mocker):
|
||||
client = PluginModelClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result=True, credentials={"token": "rotated"})]),
|
||||
)
|
||||
credentials = {"token": "old"}
|
||||
|
||||
result = client.validate_model_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model_type="llm",
|
||||
model="gpt-test",
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert credentials["token"] == "rotated"
|
||||
assert stream_mock.call_args.args[:2] == (
|
||||
"POST",
|
||||
"plugin/tenant-1/dispatch/model/validate_model_credentials",
|
||||
)
|
||||
|
||||
def test_validate_model_credentials_empty_returns_false(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
assert (
|
||||
client.validate_model_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {})
|
||||
is False
|
||||
)
|
||||
|
||||
def test_invoke_llm(self, mocker):
|
||||
client = PluginModelClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
|
||||
)
|
||||
|
||||
result = list(
|
||||
client.invoke_llm(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="gpt-test",
|
||||
credentials={"api_key": "key"},
|
||||
prompt_messages=[],
|
||||
model_parameters={"temperature": 0.1},
|
||||
tools=[],
|
||||
stop=["STOP"],
|
||||
stream=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["chunk-1"]
|
||||
call_kwargs = stream_mock.call_args.kwargs
|
||||
assert call_kwargs["path"] == "plugin/tenant-1/dispatch/llm/invoke"
|
||||
assert call_kwargs["data"]["data"]["stream"] is False
|
||||
assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1}
|
||||
|
||||
def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker):
|
||||
client = PluginModelClient()
|
||||
|
||||
def _boom():
|
||||
raise PluginDaemonInnerError(code=-500, message="invoke failed")
|
||||
yield # pragma: no cover
|
||||
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom())
|
||||
|
||||
with pytest.raises(ValueError, match="invoke failed-500"):
|
||||
list(
|
||||
client.invoke_llm(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="gpt-test",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_llm_num_tokens(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(num_tokens=42)]),
|
||||
)
|
||||
|
||||
result = client.get_llm_num_tokens(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model_type="llm",
|
||||
model="gpt-test",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
tools=[],
|
||||
)
|
||||
|
||||
assert result == 42
|
||||
|
||||
def test_get_llm_num_tokens_empty_returns_zero(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
assert (
|
||||
client.get_llm_num_tokens("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}, [])
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_invoke_text_embedding(self, mocker):
|
||||
client = PluginModelClient()
|
||||
embedding_result = SimpleNamespace(data=[[0.1, 0.2]])
|
||||
mocker.patch.object(
|
||||
client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result])
|
||||
)
|
||||
|
||||
result = client.invoke_text_embedding(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="embedding-a",
|
||||
credentials={},
|
||||
texts=["hello"],
|
||||
input_type="search_document",
|
||||
)
|
||||
|
||||
assert result is embedding_result
|
||||
|
||||
def test_invoke_text_embedding_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke text embedding"):
|
||||
client.invoke_text_embedding(
|
||||
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["hello"], "x"
|
||||
)
|
||||
|
||||
def test_invoke_multimodal_embedding(self, mocker):
|
||||
client = PluginModelClient()
|
||||
embedding_result = SimpleNamespace(data=[[0.3, 0.4]])
|
||||
mocker.patch.object(
|
||||
client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result])
|
||||
)
|
||||
|
||||
result = client.invoke_multimodal_embedding(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="embedding-a",
|
||||
credentials={},
|
||||
documents=[{"type": "image", "value": "abc"}],
|
||||
input_type="search_document",
|
||||
)
|
||||
|
||||
assert result is embedding_result
|
||||
|
||||
def test_invoke_multimodal_embedding_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke file embedding"):
|
||||
client.invoke_multimodal_embedding(
|
||||
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, [{"type": "image"}], "x"
|
||||
)
|
||||
|
||||
def test_get_text_embedding_num_tokens(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(num_tokens=[1, 2, 3])]),
|
||||
)
|
||||
|
||||
assert client.get_text_embedding_num_tokens(
|
||||
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"]
|
||||
) == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
]
|
||||
|
||||
def test_get_text_embedding_num_tokens_empty_returns_list(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
assert (
|
||||
client.get_text_embedding_num_tokens(
|
||||
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"]
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
def test_invoke_rerank(self, mocker):
|
||||
client = PluginModelClient()
|
||||
rerank_result = SimpleNamespace(scores=[0.9])
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result]))
|
||||
|
||||
result = client.invoke_rerank(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="rerank-a",
|
||||
credentials={},
|
||||
query="q",
|
||||
docs=["doc-1"],
|
||||
score_threshold=0.2,
|
||||
top_n=5,
|
||||
)
|
||||
|
||||
assert result is rerank_result
|
||||
|
||||
def test_invoke_rerank_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke rerank"):
|
||||
client.invoke_rerank("tenant-1", "user-1", "org/plugin:1", "provider-a", "rerank-a", {}, "q", ["doc-1"])
|
||||
|
||||
def test_invoke_multimodal_rerank(self, mocker):
|
||||
client = PluginModelClient()
|
||||
rerank_result = SimpleNamespace(scores=[0.8])
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result]))
|
||||
|
||||
result = client.invoke_multimodal_rerank(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="rerank-a",
|
||||
credentials={},
|
||||
query={"type": "text", "value": "q"},
|
||||
docs=[{"type": "image", "value": "doc"}],
|
||||
score_threshold=0.1,
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
assert result is rerank_result
|
||||
|
||||
def test_invoke_multimodal_rerank_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke multimodal rerank"):
|
||||
client.invoke_multimodal_rerank(
|
||||
"tenant-1",
|
||||
"user-1",
|
||||
"org/plugin:1",
|
||||
"provider-a",
|
||||
"rerank-a",
|
||||
{},
|
||||
{"type": "text"},
|
||||
[{"type": "image"}],
|
||||
)
|
||||
|
||||
def test_invoke_tts(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result="68656c6c6f"), SimpleNamespace(result="21")]),
|
||||
)
|
||||
|
||||
result = list(
|
||||
client.invoke_tts(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="tts-a",
|
||||
credentials={},
|
||||
content_text="hello",
|
||||
voice="alloy",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == [b"hello", b"!"]
|
||||
|
||||
def test_invoke_tts_wraps_plugin_daemon_inner_error(self, mocker):
|
||||
client = PluginModelClient()
|
||||
|
||||
def _boom():
|
||||
raise PluginDaemonInnerError(code=-400, message="tts error")
|
||||
yield # pragma: no cover
|
||||
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom())
|
||||
|
||||
with pytest.raises(ValueError, match="tts error-400"):
|
||||
list(client.invoke_tts("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}, "hello", "alloy"))
|
||||
|
||||
def test_get_tts_model_voices(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter(
|
||||
[
|
||||
SimpleNamespace(
|
||||
voices=[
|
||||
SimpleNamespace(name="Alloy", value="alloy"),
|
||||
SimpleNamespace(name="Echo", value="echo"),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = client.get_tts_model_voices(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="tts-a",
|
||||
credentials={},
|
||||
language="en",
|
||||
)
|
||||
|
||||
assert result == [{"name": "Alloy", "value": "alloy"}, {"name": "Echo", "value": "echo"}]
|
||||
|
||||
def test_get_tts_model_voices_empty_returns_list(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
assert client.get_tts_model_voices("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}) == []
|
||||
|
||||
def test_invoke_speech_to_text(self, mocker):
|
||||
client = PluginModelClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result="transcribed text")]),
|
||||
)
|
||||
|
||||
result = client.invoke_speech_to_text(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="stt-a",
|
||||
credentials={},
|
||||
file=io.BytesIO(b"abc"),
|
||||
)
|
||||
|
||||
assert result == "transcribed text"
|
||||
assert stream_mock.call_args.kwargs["data"]["data"]["file"] == "616263"
|
||||
|
||||
def test_invoke_speech_to_text_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke speech to text"):
|
||||
client.invoke_speech_to_text(
|
||||
"tenant-1", "user-1", "org/plugin:1", "provider-a", "stt-a", {}, io.BytesIO(b"abc")
|
||||
)
|
||||
|
||||
def test_invoke_moderation(self, mocker):
|
||||
client = PluginModelClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(result=True)]),
|
||||
)
|
||||
|
||||
result = client.invoke_moderation(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin:1",
|
||||
provider="provider-a",
|
||||
model="moderation-a",
|
||||
credentials={},
|
||||
text="safe text",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert stream_mock.call_args.kwargs["path"] == "plugin/tenant-1/dispatch/moderation/invoke"
|
||||
|
||||
def test_invoke_moderation_empty_raises(self, mocker):
|
||||
client = PluginModelClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to invoke moderation"):
|
||||
client.invoke_moderation("tenant-1", "user-1", "org/plugin:1", "provider-a", "moderation-a", {}, "unsafe")
|
||||
147
api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py
Normal file
147
api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py
Normal file
@ -0,0 +1,147 @@
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug import Request
|
||||
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
|
||||
|
||||
def _build_request(body: bytes = b"payload") -> Request:
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/oauth/callback",
|
||||
"QUERY_STRING": "code=123",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "80",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"HTTP_HOST": "localhost",
|
||||
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||
"HTTP_X_TEST": "yes",
|
||||
}
|
||||
return Request(environ)
|
||||
|
||||
|
||||
class TestOAuthHandler:
|
||||
def test_get_authorization_url(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
stream_mock = mocker.patch.object(
|
||||
handler,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(authorization_url="https://auth.example.com")]),
|
||||
)
|
||||
|
||||
response = handler.get_authorization_url(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={"client_id": "id"},
|
||||
)
|
||||
|
||||
assert response.authorization_url == "https://auth.example.com"
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_get_authorization_url_no_response_raises(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Error getting authorization URL"):
|
||||
handler.get_authorization_url(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={},
|
||||
)
|
||||
|
||||
def test_get_credentials(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
captured_data = {}
|
||||
|
||||
def fake_stream(*args, **kwargs):
|
||||
captured_data.update(kwargs["data"])
|
||||
return iter([SimpleNamespace(credentials={"token": "abc"}, metadata={}, expires_at=1)])
|
||||
|
||||
stream_mock = mocker.patch.object(
|
||||
handler, "_request_with_plugin_daemon_response_stream", side_effect=fake_stream
|
||||
)
|
||||
|
||||
response = handler.get_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={"client_id": "id"},
|
||||
request=_build_request(),
|
||||
)
|
||||
|
||||
assert response.credentials == {"token": "abc"}
|
||||
assert "raw_http_request" in captured_data["data"]
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_get_credentials_no_response_raises(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Error getting credentials"):
|
||||
handler.get_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={},
|
||||
request=_build_request(),
|
||||
)
|
||||
|
||||
def test_refresh_credentials(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
stream_mock = mocker.patch.object(
|
||||
handler,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(credentials={"token": "new"}, metadata={}, expires_at=1)]),
|
||||
)
|
||||
|
||||
response = handler.refresh_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={"client_id": "id"},
|
||||
credentials={"refresh_token": "r"},
|
||||
)
|
||||
|
||||
assert response.credentials == {"token": "new"}
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_refresh_credentials_no_response_raises(self, mocker):
|
||||
handler = OAuthHandler()
|
||||
mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="Error refreshing credentials"):
|
||||
handler.refresh_credentials(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
plugin_id="org/plugin",
|
||||
provider="provider",
|
||||
redirect_uri="https://dify.example.com/callback",
|
||||
system_credentials={},
|
||||
credentials={},
|
||||
)
|
||||
|
||||
def test_convert_request_to_raw_data(self):
|
||||
handler = OAuthHandler()
|
||||
request = _build_request(b"body-data")
|
||||
|
||||
raw = handler._convert_request_to_raw_data(request)
|
||||
|
||||
assert raw.startswith(b"POST /oauth/callback?code=123 HTTP/1.1\r\n")
|
||||
assert b"X-Test: yes\r\n" in raw
|
||||
assert raw.endswith(b"body-data")
|
||||
121
api/tests/unit_tests/core/plugin/impl/test_tool_manager.py
Normal file
121
api/tests/unit_tests/core/plugin/impl/test_tool_manager.py
Normal file
@ -0,0 +1,121 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
|
||||
|
||||
def _tool_provider(name: str = "provider") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
plugin_id="org/plugin",
|
||||
declaration=SimpleNamespace(
|
||||
identity=SimpleNamespace(name=name),
|
||||
tools=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPluginToolManager:
|
||||
def test_fetch_tool_providers(self, mocker):
|
||||
manager = PluginToolManager()
|
||||
provider = _tool_provider("remote")
|
||||
mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True})
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"declaration": {
|
||||
"identity": {"name": "remote"},
|
||||
"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"][0]["declaration"]["tools"][0]["output_schema"] == {"resolved": True}
|
||||
return [provider]
|
||||
|
||||
request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = manager.fetch_tool_providers("tenant-1")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result[0].declaration.identity.name == "org/plugin/remote"
|
||||
assert result[0].declaration.tools[0].identity.provider == "org/plugin/remote"
|
||||
|
||||
def test_fetch_tool_provider(self, mocker):
|
||||
manager = PluginToolManager()
|
||||
provider = _tool_provider("provider")
|
||||
mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True})
|
||||
|
||||
def fake_request(method, path, type_, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": {
|
||||
"declaration": {"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]}
|
||||
}
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"]["declaration"]["tools"][0]["output_schema"] == {"resolved": True}
|
||||
return provider
|
||||
|
||||
request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = manager.fetch_tool_provider("tenant-1", "org/plugin/provider")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result.declaration.identity.name == "org/plugin/provider"
|
||||
assert result.declaration.tools[0].identity.provider == "org/plugin/provider"
|
||||
|
||||
def test_invoke_merges_chunks(self, mocker):
|
||||
manager = PluginToolManager()
|
||||
stream_mock = mocker.patch.object(
|
||||
manager, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk"])
|
||||
)
|
||||
merge_mock = mocker.patch("core.plugin.impl.tool.merge_blob_chunks", return_value=["merged"])
|
||||
|
||||
result = manager.invoke(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
tool_provider="org/plugin/provider",
|
||||
tool_name="search",
|
||||
credentials={"api_key": "k"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
tool_parameters={"q": "python"},
|
||||
conversation_id="conv-1",
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
)
|
||||
|
||||
assert result == ["merged"]
|
||||
assert merge_mock.call_count == 1
|
||||
assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin"
|
||||
|
||||
def test_validate_credentials_paths(self, mocker):
|
||||
manager = PluginToolManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
|
||||
stream_mock.return_value = iter([SimpleNamespace(result=True)])
|
||||
assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
|
||||
|
||||
stream_mock.return_value = iter([])
|
||||
assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False
|
||||
|
||||
stream_mock.return_value = iter([SimpleNamespace(result=True)])
|
||||
assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
|
||||
|
||||
stream_mock.return_value = iter([])
|
||||
assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False
|
||||
|
||||
def test_get_runtime_parameters_paths(self, mocker):
|
||||
manager = PluginToolManager()
|
||||
stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
|
||||
|
||||
stream_mock.return_value = iter([SimpleNamespace(parameters=[{"name": "p"}])])
|
||||
params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search")
|
||||
assert params == [{"name": "p"}]
|
||||
|
||||
stream_mock.return_value = iter([])
|
||||
params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search")
|
||||
assert params == []
|
||||
226
api/tests/unit_tests/core/plugin/impl/test_trigger_client.py
Normal file
226
api/tests/unit_tests/core/plugin/impl/test_trigger_client.py
Normal file
@ -0,0 +1,226 @@
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug import Request
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.trigger import PluginTriggerClient
|
||||
from core.trigger.entities.entities import Subscription
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
|
||||
def _request() -> Request:
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/events",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "80",
|
||||
"wsgi.input": BytesIO(b"payload"),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": "7",
|
||||
"HTTP_HOST": "localhost",
|
||||
}
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def _subscription() -> Subscription:
|
||||
return Subscription(expires_at=123, endpoint="https://example.com/hook", parameters={"a": 1}, properties={"p": 1})
|
||||
|
||||
|
||||
def _trigger_provider(name: str = "provider") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
plugin_id="org/plugin",
|
||||
declaration=SimpleNamespace(
|
||||
identity=SimpleNamespace(name=name),
|
||||
events=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _subscription_call_kwargs(method_name: str) -> dict:
|
||||
if method_name == "subscribe":
|
||||
return {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"provider": "org/plugin/provider",
|
||||
"credentials": {"token": "x"},
|
||||
"credential_type": CredentialType.API_KEY,
|
||||
"endpoint": "https://example.com/hook",
|
||||
"parameters": {"k": "v"},
|
||||
}
|
||||
|
||||
return {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"provider": "org/plugin/provider",
|
||||
"subscription": _subscription(),
|
||||
"credentials": {"token": "x"},
|
||||
"credential_type": CredentialType.API_KEY,
|
||||
}
|
||||
|
||||
|
||||
class TestPluginTriggerClient:
|
||||
def test_fetch_trigger_providers(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
provider = _trigger_provider("remote")
|
||||
|
||||
def fake_request(*args, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"plugin_id": "org/plugin",
|
||||
"provider": "remote",
|
||||
"declaration": {"events": [{"identity": {"provider": "old"}}]},
|
||||
}
|
||||
]
|
||||
}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"][0]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/remote"
|
||||
return [provider]
|
||||
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = client.fetch_trigger_providers("tenant-1")
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result[0].declaration.identity.name == "org/plugin/remote"
|
||||
assert result[0].declaration.events[0].identity.provider == "org/plugin/remote"
|
||||
|
||||
def test_fetch_trigger_provider(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
provider = _trigger_provider("provider")
|
||||
|
||||
def fake_request(*args, **kwargs):
|
||||
transformer = kwargs["transformer"]
|
||||
payload = {"data": {"declaration": {"events": [{"identity": {"provider": "old"}}]}}}
|
||||
transformed = transformer(payload)
|
||||
assert transformed["data"]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/provider"
|
||||
return provider
|
||||
|
||||
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
|
||||
|
||||
result = client.fetch_trigger_provider("tenant-1", TriggerProviderID("org/plugin/provider"))
|
||||
|
||||
assert request_mock.call_count == 1
|
||||
assert result.declaration.identity.name == "org/plugin/provider"
|
||||
assert result.declaration.events[0].identity.provider == "org/plugin/provider"
|
||||
|
||||
def test_invoke_trigger_event(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(variables={"ok": True}, cancelled=False)]),
|
||||
)
|
||||
|
||||
result = client.invoke_trigger_event(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
provider="org/plugin/provider",
|
||||
event_name="created",
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
request=_request(),
|
||||
parameters={"k": "v"},
|
||||
subscription=_subscription(),
|
||||
payload={"payload": 1},
|
||||
)
|
||||
|
||||
assert result.variables == {"ok": True}
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
def test_invoke_trigger_event_no_response_raises(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
|
||||
with pytest.raises(ValueError, match="No response received from plugin daemon for invoke trigger"):
|
||||
client.invoke_trigger_event(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
provider="org/plugin/provider",
|
||||
event_name="created",
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
request=_request(),
|
||||
parameters={"k": "v"},
|
||||
subscription=_subscription(),
|
||||
payload={"payload": 1},
|
||||
)
|
||||
|
||||
def test_validate_provider_credentials(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
stream_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response_stream")
|
||||
|
||||
stream_mock.return_value = iter([SimpleNamespace(result=True)])
|
||||
assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
|
||||
|
||||
stream_mock.return_value = iter([])
|
||||
with pytest.raises(
|
||||
ValueError, match="No response received from plugin daemon for validate provider credentials"
|
||||
):
|
||||
client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"})
|
||||
|
||||
def test_dispatch_event(self, mocker):
|
||||
client = PluginTriggerClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(user_id="u", events=["e"])]),
|
||||
)
|
||||
|
||||
result = client.dispatch_event(
|
||||
tenant_id="tenant-1",
|
||||
provider="org/plugin/provider",
|
||||
subscription={"id": "sub"},
|
||||
request=_request(),
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result.user_id == "u"
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
stream_mock.return_value = iter([])
|
||||
with pytest.raises(ValueError, match="No response received from plugin daemon for dispatch event"):
|
||||
client.dispatch_event(
|
||||
tenant_id="tenant-1",
|
||||
provider="org/plugin/provider",
|
||||
subscription={"id": "sub"},
|
||||
request=_request(),
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["subscribe", "unsubscribe", "refresh"])
|
||||
def test_subscription_operations_success(self, mocker, method_name):
|
||||
client = PluginTriggerClient()
|
||||
stream_mock = mocker.patch.object(
|
||||
client,
|
||||
"_request_with_plugin_daemon_response_stream",
|
||||
return_value=iter([SimpleNamespace(subscription={"id": "sub"})]),
|
||||
)
|
||||
|
||||
method = getattr(client, method_name)
|
||||
result = method(**_subscription_call_kwargs(method_name))
|
||||
|
||||
assert result.subscription == {"id": "sub"}
|
||||
assert stream_mock.call_count == 1
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "expected"),
|
||||
[
|
||||
("subscribe", "No response received from plugin daemon for subscribe"),
|
||||
("unsubscribe", "No response received from plugin daemon for unsubscribe"),
|
||||
("refresh", "No response received from plugin daemon for refresh"),
|
||||
],
|
||||
)
|
||||
def test_subscription_operations_no_response(self, mocker, method_name, expected):
|
||||
client = PluginTriggerClient()
|
||||
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
|
||||
method = getattr(client, method_name)
|
||||
|
||||
with pytest.raises(ValueError, match=expected):
|
||||
method(**_subscription_call_kwargs(method_name))
|
||||
@ -1,72 +1,359 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
class _Chunk(BaseModel):
|
||||
value: int
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
class TestBaseBackwardsInvocation:
|
||||
def test_convert_to_event_stream_with_generator_and_error(self):
|
||||
def _stream():
|
||||
yield _Chunk(value=1)
|
||||
yield {"x": 2}
|
||||
yield "ignored"
|
||||
raise RuntimeError("boom")
|
||||
|
||||
chunks = list(BaseBackwardsInvocation.convert_to_event_stream(_stream()))
|
||||
|
||||
assert len(chunks) == 3
|
||||
first = json.loads(chunks[0].decode())
|
||||
second = json.loads(chunks[1].decode())
|
||||
error = json.loads(chunks[2].decode())
|
||||
assert first["data"]["value"] == 1
|
||||
assert second["data"]["x"] == 2
|
||||
assert error["error"] == "boom"
|
||||
|
||||
def test_convert_to_event_stream_with_non_generator(self):
|
||||
chunks = list(BaseBackwardsInvocation.convert_to_event_stream({"ok": True}))
|
||||
payload = json.loads(chunks[0].decode())
|
||||
assert payload["data"] == {"ok": True}
|
||||
assert payload["error"] == ""
|
||||
|
||||
|
||||
class TestPluginAppBackwardsInvocation:
|
||||
def test_fetch_app_info_workflow_path(self, mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.features_dict = {"feature": "v"}
|
||||
workflow.user_input_form.return_value = [{"name": "foo"}]
|
||||
app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mapper = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
|
||||
return_value={"mapped": True},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1")
|
||||
|
||||
assert result == {"data": {"mapped": True}}
|
||||
mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}])
|
||||
|
||||
def test_fetch_app_info_model_config_path(self, mocker):
|
||||
model_config = MagicMock()
|
||||
model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"}
|
||||
app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
|
||||
return_value={"mapped": True},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1")
|
||||
|
||||
assert result["data"] == {"mapped": True}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "route_method"),
|
||||
[
|
||||
(AppMode.CHAT, "invoke_chat_app"),
|
||||
(AppMode.ADVANCED_CHAT, "invoke_chat_app"),
|
||||
(AppMode.AGENT_CHAT, "invoke_chat_app"),
|
||||
(AppMode.WORKFLOW, "invoke_workflow_app"),
|
||||
(AppMode.COMPLETION, "invoke_completion_app"),
|
||||
],
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
def test_invoke_app_routes_by_mode(self, mocker, mode, route_method):
|
||||
app = MagicMock(mode=mode)
|
||||
user = MagicMock()
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user)
|
||||
route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True})
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"x": 1},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"routed": True}
|
||||
assert route.call_count == 1
|
||||
|
||||
def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker):
|
||||
app = MagicMock(mode=AppMode.WORKFLOW)
|
||||
end_user = MagicMock()
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
get_or_create = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user",
|
||||
return_value=end_user,
|
||||
)
|
||||
route = mocker.patch.object(PluginAppBackwardsInvocation, "invoke_workflow_app", return_value={"ok": True})
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id="app",
|
||||
user_id="",
|
||||
tenant_id="tenant",
|
||||
conversation_id="",
|
||||
query=None,
|
||||
stream=True,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
get_or_create.assert_called_once_with(app)
|
||||
assert route.call_args.args[1] is end_user
|
||||
|
||||
def test_invoke_app_missing_query_for_chat_raises(self, mocker):
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT))
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
|
||||
|
||||
with pytest.raises(ValueError, match="missing query"):
|
||||
PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
query="",
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_invoke_app_unexpected_mode_raises(self, mocker):
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode="other"))
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
|
||||
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
query="q",
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "generator_path"),
|
||||
[
|
||||
(AppMode.AGENT_CHAT, "core.plugin.backwards_invocation.app.AgentChatAppGenerator.generate"),
|
||||
(AppMode.CHAT, "core.plugin.backwards_invocation.app.ChatAppGenerator.generate"),
|
||||
],
|
||||
)
|
||||
def test_invoke_chat_app_agent_and_chat(self, mocker, mode, generator_path):
|
||||
app = MagicMock(mode=mode, workflow=None)
|
||||
spy = mocker.patch(generator_path, return_value={"result": "ok"})
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
result = PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
assert result == {"result": "ok"}
|
||||
assert spy.call_count == 1
|
||||
|
||||
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(self, mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
def test_invoke_workflow_app_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
app.workflow = workflow
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.workflow = workflow
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
result = PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self):
|
||||
app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None)
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_invoke_chat_app_unexpected_mode_raises(self):
|
||||
app = MagicMock(mode="invalid")
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_invoke_workflow_app_injects_pause_state_config(self, mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
def test_invoke_workflow_app_without_workflow_raises(self):
|
||||
app = MagicMock(mode=AppMode.WORKFLOW, workflow=None)
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_invoke_completion_app(self, mocker):
|
||||
spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.CompletionAppGenerator.generate", return_value={"ok": 1}
|
||||
)
|
||||
app = MagicMock(mode=AppMode.COMPLETION)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_completion_app(app, MagicMock(), False, {"x": 1}, [])
|
||||
|
||||
assert result == {"ok": 1}
|
||||
assert spy.call_count == 1
|
||||
|
||||
def test_get_user_returns_end_user(self, mocker):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [MagicMock(id="end-user")]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid")
|
||||
assert user.id == "end-user"
|
||||
|
||||
def test_get_user_falls_back_to_account_user(self, mocker):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [None, MagicMock(id="account-user")]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid")
|
||||
assert user.id == "account-user"
|
||||
|
||||
def test_get_user_raises_when_user_not_found(self, mocker):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [None, None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
|
||||
def test_get_app_returns_app(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
app_obj = MagicMock(id="app")
|
||||
query_chain.first.return_value = app_obj
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
query_chain.first.return_value = None
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
347
api/tests/unit_tests/core/plugin/test_plugin_entities.py
Normal file
347
api/tests/unit_tests/core/plugin/test_plugin_entities.py
Normal file
@ -0,0 +1,347 @@
|
||||
import binascii
|
||||
import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeSpeech2Text,
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeEventResponse,
|
||||
)
|
||||
from core.plugin.utils.http_parser import serialize_response
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestEndpointEntity:
|
||||
def test_endpoint_entity_with_instance_renders_url(self, mocker):
|
||||
mocker.patch("core.plugin.entities.endpoint.dify_config.ENDPOINT_URL_TEMPLATE", "https://dify.test/{hook_id}")
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
|
||||
entity = EndpointEntityWithInstance.model_validate(
|
||||
{
|
||||
"id": "ep-1",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"settings": {},
|
||||
"tenant_id": "tenant",
|
||||
"plugin_id": "org/plugin",
|
||||
"expired_at": now,
|
||||
"name": "my-endpoint",
|
||||
"enabled": True,
|
||||
"hook_id": "hook-123",
|
||||
}
|
||||
)
|
||||
|
||||
assert entity.url == "https://dify.test/hook-123"
|
||||
|
||||
def test_endpoint_entity_with_instance_keeps_existing_url(self):
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
entity = EndpointEntityWithInstance.model_validate(
|
||||
{
|
||||
"id": "ep-1",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"settings": {},
|
||||
"tenant_id": "tenant",
|
||||
"plugin_id": "org/plugin",
|
||||
"expired_at": now,
|
||||
"name": "my-endpoint",
|
||||
"enabled": True,
|
||||
"hook_id": "hook-123",
|
||||
"url": "https://preset.test/hook-123",
|
||||
}
|
||||
)
|
||||
assert entity.url == "https://preset.test/hook-123"
|
||||
|
||||
|
||||
class TestMarketplaceEntities:
|
||||
def test_marketplace_declaration_strips_empty_optional_fields(self):
|
||||
declaration = MarketplacePluginDeclaration.model_validate(
|
||||
{
|
||||
"name": "plugin",
|
||||
"org": "org",
|
||||
"plugin_id": "org/plugin",
|
||||
"icon": "icon.png",
|
||||
"label": {"en_US": "Plugin"},
|
||||
"brief": {"en_US": "Brief"},
|
||||
"resource": {"memory": 256},
|
||||
"endpoint": {},
|
||||
"model": {},
|
||||
"tool": {},
|
||||
"latest_version": "1.0.0",
|
||||
"latest_package_identifier": "org/plugin@1.0.0",
|
||||
"status": "active",
|
||||
"deprecated_reason": "",
|
||||
"alternative_plugin_id": "",
|
||||
}
|
||||
)
|
||||
|
||||
assert declaration.endpoint is None
|
||||
assert declaration.model is None
|
||||
assert declaration.tool is None
|
||||
|
||||
def test_marketplace_snapshot_computed_plugin_id(self):
|
||||
snapshot = MarketplacePluginSnapshot(
|
||||
org="langgenius",
|
||||
name="search",
|
||||
latest_version="1.0.0",
|
||||
latest_package_identifier="langgenius/search@1.0.0",
|
||||
latest_package_url="https://example.com/pkg",
|
||||
)
|
||||
assert snapshot.plugin_id == "langgenius/search"
|
||||
|
||||
|
||||
class TestPluginParameterEntities:
|
||||
def _label(self) -> I18nObject:
|
||||
return I18nObject(en_US="label")
|
||||
|
||||
def test_parameter_option_value_casts_to_string(self):
|
||||
option = PluginParameterOption(value=123, label=self._label())
|
||||
assert option.value == "123"
|
||||
|
||||
def test_plugin_parameter_options_non_list_defaults_to_empty(self):
|
||||
parameter = PluginParameter(name="p", label=self._label(), options="invalid") # type: ignore[arg-type]
|
||||
assert parameter.options == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("parameter_type", "expected"),
|
||||
[
|
||||
(PluginParameterType.SECRET_INPUT, "string"),
|
||||
(PluginParameterType.SELECT, "string"),
|
||||
(PluginParameterType.CHECKBOX, "string"),
|
||||
(PluginParameterType.NUMBER, PluginParameterType.NUMBER.value),
|
||||
],
|
||||
)
|
||||
def test_as_normal_type(self, parameter_type, expected):
|
||||
assert as_normal_type(parameter_type) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[(None, ""), (1, "1"), ("abc", "abc")],
|
||||
)
|
||||
def test_cast_parameter_value_string_like(self, value, expected):
|
||||
assert cast_parameter_value(PluginParameterType.STRING, value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, False),
|
||||
("true", True),
|
||||
("yes", True),
|
||||
("1", True),
|
||||
("false", False),
|
||||
("0", False),
|
||||
("random", True),
|
||||
(1, True),
|
||||
(0, False),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_boolean(self, value, expected):
|
||||
assert cast_parameter_value(PluginParameterType.BOOLEAN, value) is expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(1, 1),
|
||||
(1.5, 1.5),
|
||||
("2", 2),
|
||||
("2.5", 2.5),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_number(self, value, expected):
|
||||
assert cast_parameter_value(PluginParameterType.NUMBER, value) == expected
|
||||
|
||||
def test_cast_parameter_value_file_and_files(self):
|
||||
assert cast_parameter_value(PluginParameterType.FILES, "f1") == ["f1"]
|
||||
assert cast_parameter_value(PluginParameterType.SYSTEM_FILES, ["f1", "f2"]) == ["f1", "f2"]
|
||||
assert cast_parameter_value(PluginParameterType.FILE, ["one"]) == "one"
|
||||
assert cast_parameter_value(PluginParameterType.FILE, "one") == "one"
|
||||
with pytest.raises(ValueError, match="only accepts one file"):
|
||||
cast_parameter_value(PluginParameterType.FILE, ["a", "b"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("parameter_type", "value", "expected"),
|
||||
[
|
||||
(PluginParameterType.MODEL_SELECTOR, {"m": "gpt"}, {"m": "gpt"}),
|
||||
(PluginParameterType.APP_SELECTOR, {"app": "a"}, {"app": "a"}),
|
||||
(PluginParameterType.TOOLS_SELECTOR, [], []),
|
||||
(PluginParameterType.ANY, {"k": "v"}, {"k": "v"}),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_selectors_valid(self, parameter_type, value, expected):
|
||||
assert cast_parameter_value(parameter_type, value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("parameter_type", "value", "message"),
|
||||
[
|
||||
(PluginParameterType.MODEL_SELECTOR, "bad", "selector must be a dictionary"),
|
||||
(PluginParameterType.APP_SELECTOR, "bad", "selector must be a dictionary"),
|
||||
(PluginParameterType.TOOLS_SELECTOR, "bad", "tools selector must be a list"),
|
||||
(PluginParameterType.ANY, object(), "var selector must be"),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_selectors_invalid(self, parameter_type, value, message):
|
||||
with pytest.raises(ValueError, match=message):
|
||||
cast_parameter_value(parameter_type, value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("parameter_type", "value", "expected"),
|
||||
[
|
||||
(PluginParameterType.ARRAY, [1, 2], [1, 2]),
|
||||
(PluginParameterType.ARRAY, "[1, 2]", [1, 2]),
|
||||
(PluginParameterType.OBJECT, {"k": "v"}, {"k": "v"}),
|
||||
(PluginParameterType.OBJECT, '{"a":1}', {"a": 1}),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_array_and_object_valid(self, parameter_type, value, expected):
|
||||
assert cast_parameter_value(parameter_type, value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("parameter_type", "value", "expected"),
|
||||
[
|
||||
(PluginParameterType.ARRAY, "bad-json", ["bad-json"]),
|
||||
(PluginParameterType.OBJECT, "bad-json", {}),
|
||||
],
|
||||
)
|
||||
def test_cast_parameter_value_array_and_object_invalid_json_fallback(self, parameter_type, value, expected):
|
||||
assert cast_parameter_value(parameter_type, value) == expected
|
||||
|
||||
def test_cast_parameter_value_default_branch_and_wrapped_exception(self):
|
||||
class _Unknown(StrEnum):
|
||||
CUSTOM = "custom"
|
||||
|
||||
assert cast_parameter_value(_Unknown.CUSTOM, 12) == "12"
|
||||
|
||||
class _BadString:
|
||||
def __str__(self):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"The tool parameter value <.*_BadString object at .* is not in correct type of string\.",
|
||||
):
|
||||
cast_parameter_value(PluginParameterType.STRING, _BadString())
|
||||
|
||||
def test_init_frontend_parameter(self):
|
||||
rule = PluginParameter(
|
||||
name="choice",
|
||||
label=self._label(),
|
||||
required=True,
|
||||
default="a",
|
||||
options=[PluginParameterOption(value="a", label=self._label())],
|
||||
)
|
||||
|
||||
assert init_frontend_parameter(rule, PluginParameterType.SELECT, None) == "a"
|
||||
assert init_frontend_parameter(rule, PluginParameterType.NUMBER, 0) == 0
|
||||
with pytest.raises(ValueError, match="not in options"):
|
||||
init_frontend_parameter(rule, PluginParameterType.SELECT, "b")
|
||||
|
||||
required_rule = PluginParameter(name="required", label=self._label(), required=True, default=None)
|
||||
with pytest.raises(ValueError, match="not found in tool config"):
|
||||
init_frontend_parameter(required_rule, PluginParameterType.STRING, None)
|
||||
|
||||
|
||||
class TestPluginDaemonEntities:
|
||||
def test_credential_type_helpers(self):
|
||||
assert CredentialType.API_KEY.get_name() == "API KEY"
|
||||
assert CredentialType.OAUTH2.get_name() == "AUTH"
|
||||
assert CredentialType.UNAUTHORIZED.get_name() == "UNAUTHORIZED"
|
||||
|
||||
class _FakeCredential:
|
||||
value = "custom-type"
|
||||
|
||||
assert CredentialType.get_name(_FakeCredential()) == "CUSTOM TYPE"
|
||||
assert CredentialType.API_KEY.is_editable() is True
|
||||
assert CredentialType.OAUTH2.is_editable() is False
|
||||
assert CredentialType.API_KEY.is_validate_allowed() is True
|
||||
assert CredentialType.UNAUTHORIZED.is_validate_allowed() is False
|
||||
assert set(CredentialType.values()) == {"api-key", "oauth2", "unauthorized"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
("api-key", CredentialType.API_KEY),
|
||||
("api_key", CredentialType.API_KEY),
|
||||
("oauth2", CredentialType.OAUTH2),
|
||||
("oauth", CredentialType.OAUTH2),
|
||||
("unauthorized", CredentialType.UNAUTHORIZED),
|
||||
],
|
||||
)
|
||||
def test_credential_type_of(self, raw, expected):
|
||||
assert CredentialType.of(raw) == expected
|
||||
|
||||
def test_credential_type_of_invalid(self):
|
||||
with pytest.raises(ValueError, match="Invalid credential type"):
|
||||
CredentialType.of("invalid")
|
||||
|
||||
|
||||
class TestPluginRequestEntities:
|
||||
def test_request_invoke_llm_converts_prompt_messages(self):
|
||||
payload = RequestInvokeLLM(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
mode="chat",
|
||||
prompt_messages=[
|
||||
{"role": "user", "content": "u"},
|
||||
{"role": "assistant", "content": "a"},
|
||||
{"role": "system", "content": "s"},
|
||||
{"role": "tool", "content": "t", "tool_call_id": "call-1"},
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(payload.prompt_messages[0], UserPromptMessage)
|
||||
assert isinstance(payload.prompt_messages[1], AssistantPromptMessage)
|
||||
assert isinstance(payload.prompt_messages[2], SystemPromptMessage)
|
||||
assert isinstance(payload.prompt_messages[3], ToolPromptMessage)
|
||||
|
||||
def test_request_invoke_llm_prompt_messages_must_be_list(self):
|
||||
with pytest.raises(ValidationError):
|
||||
RequestInvokeLLM(provider="openai", model="gpt-4", mode="chat", prompt_messages="invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_request_invoke_speech2text_hex_conversion_and_error(self):
|
||||
payload = RequestInvokeSpeech2Text(provider="openai", model="m", file=binascii.hexlify(b"abc").decode())
|
||||
assert payload.file == b"abc"
|
||||
with pytest.raises(ValidationError):
|
||||
RequestInvokeSpeech2Text(provider="openai", model="m", file=b"abc") # type: ignore[arg-type]
|
||||
|
||||
def test_trigger_invoke_event_response_variables_conversion(self):
|
||||
converted = TriggerInvokeEventResponse(variables='{"a": 1}', cancelled=False)
|
||||
assert converted.variables == {"a": 1}
|
||||
passthrough = TriggerInvokeEventResponse(variables={"b": 2}, cancelled=True)
|
||||
assert passthrough.variables == {"b": 2}
|
||||
|
||||
def test_trigger_dispatch_response_convert_response(self):
|
||||
response = Response("ok", status=202, headers={"X-Req": "1"})
|
||||
encoded = binascii.hexlify(serialize_response(response)).decode()
|
||||
parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded)
|
||||
assert parsed.response.status_code == 202
|
||||
assert parsed.response.get_data() == b"ok"
|
||||
with pytest.raises(ValidationError):
|
||||
TriggerDispatchResponse(user_id="u", events=["e"], response="not-hex")
|
||||
|
||||
def test_trigger_dispatch_response_payload_default(self):
|
||||
response = Response("ok", status=200)
|
||||
encoded = binascii.hexlify(serialize_response(response)).decode()
|
||||
parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded)
|
||||
assert parsed.payload == {}
|
||||
@ -4,7 +4,10 @@ import pytest
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
class TestChunkMerger:
|
||||
@ -458,3 +461,89 @@ class TestChunkMerger:
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
|
||||
assert result[0].message.blob == b"FirstSecondThird"
|
||||
|
||||
|
||||
class TestConverter:
|
||||
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
|
||||
file_param = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/file.png",
|
||||
storage_key="",
|
||||
)
|
||||
selector = ToolSelector(
|
||||
provider_id="org/plugin/provider",
|
||||
credential_id=None,
|
||||
tool_name="search",
|
||||
tool_description="search tool",
|
||||
tool_configuration={"k": "v"},
|
||||
tool_parameters={
|
||||
"query": ToolSelector.Parameter(
|
||||
name="query",
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
description="query",
|
||||
default="python",
|
||||
options=[],
|
||||
)
|
||||
},
|
||||
)
|
||||
params = {"file": file_param, "selector": selector, "plain": 123}
|
||||
|
||||
converted = convert_parameters_to_plugin_format(params)
|
||||
|
||||
assert converted["file"]["url"] == "https://example.com/file.png"
|
||||
assert converted["selector"]["provider_id"] == "org/plugin/provider"
|
||||
assert converted["plain"] == 123
|
||||
|
||||
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
|
||||
file_one = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/a.txt",
|
||||
storage_key="",
|
||||
)
|
||||
file_two = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/b.txt",
|
||||
storage_key="",
|
||||
)
|
||||
selector_one = ToolSelector(
|
||||
provider_id="org/plugin/provider",
|
||||
credential_id="cred-1",
|
||||
tool_name="t1",
|
||||
tool_description="tool 1",
|
||||
tool_configuration={},
|
||||
tool_parameters={},
|
||||
)
|
||||
selector_two = ToolSelector(
|
||||
provider_id="org/plugin/provider",
|
||||
credential_id="cred-2",
|
||||
tool_name="t2",
|
||||
tool_description="tool 2",
|
||||
tool_configuration={},
|
||||
tool_parameters={},
|
||||
)
|
||||
|
||||
params = {
|
||||
"files": [file_one, file_two],
|
||||
"selectors": [selector_one, selector_two],
|
||||
"empty_list": [],
|
||||
"mixed_list": [file_one, "raw"],
|
||||
"none_value": None,
|
||||
}
|
||||
|
||||
converted = convert_parameters_to_plugin_format(params)
|
||||
|
||||
assert [item["url"] for item in converted["files"]] == [
|
||||
"https://example.com/a.txt",
|
||||
"https://example.com/b.txt",
|
||||
]
|
||||
assert [item["tool_name"] for item in converted["selectors"]] == ["t1", "t2"]
|
||||
assert converted["empty_list"] == []
|
||||
assert converted["mixed_list"] == [file_one, "raw"]
|
||||
assert converted["none_value"] is None
|
||||
|
||||
@ -381,6 +381,54 @@ class TestEdgeCases:
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == binary_body
|
||||
|
||||
def test_deserialize_request_with_lf_only_newlines(self):
|
||||
raw_data = b"POST /lf-only?x=1 HTTP/1.1\nHost: localhost\nX-Test: yes\n\npayload"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/lf-only"
|
||||
assert request.args.get("x") == "1"
|
||||
assert request.headers.get("X-Test") == "yes"
|
||||
assert request.get_data() == b"payload"
|
||||
|
||||
def test_deserialize_request_without_header_separator_uses_full_input_as_headers(self):
|
||||
raw_data = b"GET /no-separator HTTP/1.1\nHost: localhost\nInvalidHeader\n"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/no-separator"
|
||||
assert request.headers.get("Host") == "localhost"
|
||||
assert request.headers.get("InvalidHeader") is None
|
||||
|
||||
def test_deserialize_request_empty_payload_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty HTTP request"):
|
||||
deserialize_request(b"")
|
||||
|
||||
def test_deserialize_response_with_lf_only_newlines(self):
|
||||
raw_data = b"HTTP/1.1 202 Accepted\nX-Test: yes\n\nbody"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.headers.get("X-Test") == "yes"
|
||||
assert response.get_data() == b"body"
|
||||
|
||||
def test_deserialize_response_without_header_separator_uses_full_input_as_headers(self):
|
||||
raw_data = b"HTTP/1.1 204 No Content\nX-Test: yes\nInvalidHeader\n"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert response.headers.get("X-Test") == "yes"
|
||||
assert response.headers.get("InvalidHeader") is None
|
||||
assert response.get_data() == b""
|
||||
|
||||
def test_deserialize_response_empty_payload_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty HTTP response"):
|
||||
deserialize_response(b"")
|
||||
|
||||
|
||||
class TestFileUploads:
|
||||
def test_serialize_request_with_text_file_upload(self):
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import Conversation
|
||||
@ -188,3 +191,328 @@ def get_chat_model_args():
|
||||
context = "I am superman."
|
||||
|
||||
return model_config_mock, memory_config, prompt_messages, inputs, context
|
||||
|
||||
|
||||
def test_get_prompt_dispatches_completion_and_chat_and_invalid():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigEntity)
|
||||
completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic")
|
||||
chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")]
|
||||
|
||||
transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")])
|
||||
transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")])
|
||||
|
||||
completion_result = transform.get_prompt(
|
||||
prompt_template=completion_template,
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert completion_result[0].content == "c"
|
||||
|
||||
chat_result = transform.get_prompt(
|
||||
prompt_template=chat_template,
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert chat_result[0].content == "h"
|
||||
|
||||
invalid_result = transform.get_prompt(
|
||||
prompt_template=cast(list, ["not-chat-model-message"]),
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert invalid_result == []
|
||||
|
||||
|
||||
def test_completion_prompt_jinja2_with_files():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
transform = AdvancedPromptTransform()
|
||||
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
|
||||
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"),
|
||||
patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content,
|
||||
):
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_completion_model_prompt_messages(
|
||||
prompt_template=completion_template,
|
||||
inputs={"name": "John"},
|
||||
query="",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0].content, list)
|
||||
assert messages[0].content[0].data == "https://example.com/image.jpg"
|
||||
assert isinstance(messages[0].content[1], TextPromptMessageContent)
|
||||
assert messages[0].content[1].data == "Hi John"
|
||||
|
||||
|
||||
def test_completion_prompt_basic_sets_query_variable():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
transform = AdvancedPromptTransform()
|
||||
template = CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic")
|
||||
|
||||
messages = transform._get_completion_model_prompt_messages(
|
||||
prompt_template=template,
|
||||
inputs={},
|
||||
query="what?",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert messages[0].content == "Q=what?"
|
||||
|
||||
|
||||
def test_chat_prompt_with_variable_template_and_context():
|
||||
transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
prompt_template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)]
|
||||
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"#node.name#": "john"},
|
||||
query=None,
|
||||
files=[],
|
||||
context="context-text",
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], SystemPromptMessage)
|
||||
assert messages[0].content == "sys=john ctx=context-text"
|
||||
|
||||
|
||||
def test_chat_prompt_jinja2_branch_and_invalid_edition():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")]
|
||||
|
||||
with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"):
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"name": "John"},
|
||||
query=None,
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert messages[0].content == "Hello John"
|
||||
|
||||
bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")]
|
||||
with pytest.raises(ValueError, match="Invalid edition type"):
|
||||
transform._get_chat_model_prompt_messages(
|
||||
prompt_template=bad_prompt_template,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_prompt_query_template_and_query_only_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
query_prompt_template="query={{#sys.query#}} ctx={{#context#}}",
|
||||
)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="what",
|
||||
files=[],
|
||||
context="ctx",
|
||||
memory_config=memory_config,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert messages[-1].content == "query={{#sys.query#}} ctx=ctx"
|
||||
|
||||
|
||||
def test_chat_prompt_memory_with_files_and_query():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
memory = MagicMock(spec=TokenBufferMemory)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
transform._append_chat_histories = MagicMock(
|
||||
side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages
|
||||
)
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "q"
|
||||
|
||||
|
||||
def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)]
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_with_last_user,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "u"
|
||||
|
||||
prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)]
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_without_last_user,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert isinstance(messages[-1], UserPromptMessage)
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == ""
|
||||
|
||||
|
||||
def test_chat_prompt_files_with_query_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=[],
|
||||
inputs={},
|
||||
query="query-text",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "query-text"
|
||||
|
||||
|
||||
def test_set_context_query_histories_variable_helpers():
|
||||
transform = AdvancedPromptTransform()
|
||||
parser_context = PromptTemplateParser(template="{{#context#}}")
|
||||
parser_query = PromptTemplateParser(template="{{#query#}}")
|
||||
parser_hist = PromptTemplateParser(template="{{#histories#}}")
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
)
|
||||
|
||||
assert transform._set_context_variable(None, parser_context, {})["#context#"] == ""
|
||||
assert transform._set_query_variable("", parser_query, {})["#query#"] == ""
|
||||
assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x"
|
||||
assert (
|
||||
transform._set_histories_variable(
|
||||
memory=None, # type: ignore[arg-type]
|
||||
memory_config=memory_config,
|
||||
raw_prompt="{{#histories#}}",
|
||||
role_prefix=memory_config.role_prefix, # type: ignore[arg-type]
|
||||
parser=parser_hist,
|
||||
prompt_inputs={},
|
||||
model_config=model_config_mock,
|
||||
)["#histories#"]
|
||||
== ""
|
||||
)
|
||||
|
||||
@ -2,12 +2,14 @@ from uuid import uuid4
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, id, parent_message_id):
|
||||
def __init__(self, id, parent_message_id, answer="answer"):
|
||||
self.id = id
|
||||
self.parent_message_id = parent_message_id
|
||||
self.answer = answer
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages():
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_breaks_when_parent_is_none():
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [MockMessage(id2, None), MockMessage(id1, UUID_NIL)]
|
||||
|
||||
result = extract_thread_messages(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == id2
|
||||
|
||||
|
||||
def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker):
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
MockMessage(id2, id1, answer=""), # newest generated message should be excluded
|
||||
MockMessage(id1, UUID_NIL, answer="ok"),
|
||||
]
|
||||
|
||||
mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
|
||||
mock_scalars.return_value.all.return_value = messages
|
||||
|
||||
length = get_thread_messages_length("conversation-1")
|
||||
|
||||
assert length == 1
|
||||
mock_scalars.assert_called_once()
|
||||
|
||||
|
||||
def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker):
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
MockMessage(id2, id1, answer="latest-answer"),
|
||||
MockMessage(id1, UUID_NIL, answer="older-answer"),
|
||||
]
|
||||
|
||||
mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
|
||||
mock_scalars.return_value.all.return_value = messages
|
||||
|
||||
length = get_thread_messages_length("conversation-2")
|
||||
|
||||
assert length == 2
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
@ -25,3 +30,82 @@ def test_dump_prompt_message():
|
||||
)
|
||||
data = prompt.model_dump()
|
||||
assert data["content"][0].get("url") == example_url
|
||||
|
||||
|
||||
def test_prompt_messages_to_prompt_for_saving_chat_mode():
|
||||
chat_messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="hello "),
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image1.jpg",
|
||||
format="jpg",
|
||||
mime_type="image/jpeg",
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
),
|
||||
AudioPromptMessageContent(
|
||||
url="https://example.com/audio1.mp3",
|
||||
format="mp3",
|
||||
mime_type="audio/mpeg",
|
||||
),
|
||||
TextPromptMessageContent(data="world"),
|
||||
]
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content="assistant-text",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tool-1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": '{"q":"python"}'},
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"),
|
||||
UserPromptMessage.model_construct(role="unknown", content="skip"), # type: ignore[arg-type]
|
||||
]
|
||||
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages)
|
||||
|
||||
assert len(prompts) == 3
|
||||
assert prompts[0]["role"] == "user"
|
||||
assert prompts[0]["text"] == "hello world"
|
||||
assert prompts[0]["files"][0]["type"] == "image"
|
||||
assert prompts[0]["files"][1]["type"] == "audio"
|
||||
|
||||
assert prompts[1]["role"] == "assistant"
|
||||
assert prompts[1]["text"] == "assistant-text"
|
||||
assert prompts[1]["tool_calls"][0]["function"]["name"] == "search"
|
||||
assert prompts[2]["role"] == "tool"
|
||||
|
||||
|
||||
def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files():
|
||||
completion_message_with_files = UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="first "),
|
||||
TextPromptMessageContent(data="second"),
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image2.jpg",
|
||||
format="jpg",
|
||||
mime_type="image/jpeg",
|
||||
detail=ImagePromptMessageContent.DETAIL.LOW,
|
||||
),
|
||||
]
|
||||
)
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
ModelMode.COMPLETION, [completion_message_with_files]
|
||||
)
|
||||
assert prompts == [
|
||||
{
|
||||
"role": "user",
|
||||
"text": "first second",
|
||||
"files": prompts[0]["files"],
|
||||
}
|
||||
]
|
||||
assert prompts[0]["files"][0]["type"] == "image"
|
||||
|
||||
completion_message_text_only = UserPromptMessage(content="plain text")
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
ModelMode.COMPLETION, [completion_message_text_only]
|
||||
)
|
||||
assert prompts == [{"role": "user", "text": "plain text"}]
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
# from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
|
||||
# from core.app.app_config.entities import ModelConfigEntity
|
||||
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
@ -9,44 +15,217 @@
|
||||
# from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
# def test__calculate_rest_token():
|
||||
# model_schema_mock = MagicMock(spec=AIModelEntity)
|
||||
# parameter_rule_mock = MagicMock(spec=ParameterRule)
|
||||
# parameter_rule_mock.name = "max_tokens"
|
||||
# model_schema_mock.parameter_rules = [parameter_rule_mock]
|
||||
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
|
||||
class TestPromptTransform:
|
||||
def test_resolve_model_runtime_requires_model_config_or_instance(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
# large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
# large_language_model_mock.get_num_tokens.return_value = 6
|
||||
with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."):
|
||||
transform._resolve_model_runtime()
|
||||
|
||||
# provider_mock = MagicMock(spec=ProviderEntity)
|
||||
# provider_mock.provider = "openai"
|
||||
def test_resolve_model_runtime_builds_model_instance_from_model_config(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = fake_model_schema
|
||||
fake_model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials=None,
|
||||
parameters=None,
|
||||
stop=None,
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="config-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={"temperature": 0.1},
|
||||
stop=["END"],
|
||||
model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
|
||||
)
|
||||
|
||||
# provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
|
||||
# provider_configuration_mock.provider = provider_mock
|
||||
# provider_configuration_mock.model_settings = None
|
||||
with patch(
|
||||
"core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance
|
||||
) as model_instance_cls:
|
||||
model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config)
|
||||
|
||||
# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
|
||||
# provider_model_bundle_mock.model_type_instance = large_language_model_mock
|
||||
# provider_model_bundle_mock.configuration = provider_configuration_mock
|
||||
model_instance_cls.assert_called_once_with(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model,
|
||||
)
|
||||
fake_model_type_instance.get_model_schema.assert_called_once_with(
|
||||
model="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
assert model_instance is fake_model_instance
|
||||
assert model_instance.credentials == {"api_key": "secret"}
|
||||
assert model_instance.parameters == {"temperature": 0.1}
|
||||
assert model_instance.stop == ["END"]
|
||||
assert model_schema is fake_model_schema
|
||||
|
||||
# model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
# model_config_mock.model = "gpt-4"
|
||||
# model_config_mock.credentials = {}
|
||||
# model_config_mock.parameters = {"max_tokens": 50}
|
||||
# model_config_mock.model_schema = model_schema_mock
|
||||
# model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
def test_resolve_model_runtime_uses_model_config_schema_fallback(self):
|
||||
transform = PromptTransform()
|
||||
fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = None
|
||||
model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={},
|
||||
)
|
||||
model_config = SimpleNamespace(model_schema=fallback_schema)
|
||||
|
||||
# prompt_transform = PromptTransform()
|
||||
resolved_model_instance, resolved_schema = transform._resolve_model_runtime(
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
|
||||
# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
|
||||
assert resolved_model_instance is model_instance
|
||||
assert resolved_schema is fallback_schema
|
||||
|
||||
# # Validate based on the mock configuration and expected logic
|
||||
# expected_rest_tokens = (
|
||||
# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
# - model_config_mock.parameters["max_tokens"]
|
||||
# - large_language_model_mock.get_num_tokens.return_value
|
||||
# )
|
||||
# assert rest_tokens == expected_rest_tokens
|
||||
# assert rest_tokens == 6
|
||||
def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = None
|
||||
model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for the provided model instance."):
|
||||
transform._resolve_model_runtime(model_instance=model_instance)
|
||||
|
||||
def test_calculate_rest_token_defaults_when_context_size_missing(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0)
|
||||
fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([], model_config=model_config)
|
||||
|
||||
assert rest == 2000
|
||||
|
||||
def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
parameter_rule = SimpleNamespace(name="max_tokens", use_template=None)
|
||||
fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95)
|
||||
fake_model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[parameter_rule],
|
||||
)
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[parameter_rule],
|
||||
),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={"max_tokens": 50},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
|
||||
|
||||
assert rest == 0
|
||||
|
||||
def test_calculate_rest_token_supports_use_template_parameter(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens")
|
||||
fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20)
|
||||
fake_model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
|
||||
parameter_rules=[parameter_rule],
|
||||
)
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
|
||||
parameter_rules=[parameter_rule],
|
||||
),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
|
||||
|
||||
assert rest == 150
|
||||
|
||||
def test_get_history_messages_from_memory_with_and_without_window(self):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory.get_history_prompt_text.return_value = "history"
|
||||
|
||||
memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3))
|
||||
result = transform._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=memory_config_with_window,
|
||||
max_token_limit=100,
|
||||
human_prefix="Human",
|
||||
ai_prefix="Assistant",
|
||||
)
|
||||
|
||||
assert result == "history"
|
||||
memory.get_history_prompt_text.assert_called_with(
|
||||
max_token_limit=100,
|
||||
human_prefix="Human",
|
||||
ai_prefix="Assistant",
|
||||
message_limit=3,
|
||||
)
|
||||
|
||||
memory.reset_mock()
|
||||
memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2))
|
||||
transform._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=memory_config_no_window,
|
||||
max_token_limit=50,
|
||||
)
|
||||
memory.get_history_prompt_text.assert_called_with(max_token_limit=50)
|
||||
|
||||
def test_get_history_messages_list_from_memory_with_and_without_window(self):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory.get_history_prompt_messages.return_value = ["m1", "m2"]
|
||||
|
||||
memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2))
|
||||
result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120)
|
||||
assert result == ["m1", "m2"]
|
||||
memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2)
|
||||
|
||||
memory.reset_mock()
|
||||
memory.get_history_prompt_messages.return_value = ["only"]
|
||||
memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0))
|
||||
result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10)
|
||||
assert result == ["only"]
|
||||
memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None)
|
||||
|
||||
def test_append_chat_histories_extends_prompt_messages(self, monkeypatch):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None))
|
||||
|
||||
monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99)
|
||||
monkeypatch.setattr(
|
||||
transform,
|
||||
"_get_history_messages_list_from_memory",
|
||||
lambda memory, memory_config, max_token_limit: ["h1", "h2"],
|
||||
)
|
||||
|
||||
result = transform._append_chat_histories(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
prompt_messages=["p1"],
|
||||
model_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert result == ["p1", "h1", "h2"]
|
||||
|
||||
@ -1,9 +1,29 @@
|
||||
from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import AppMode, Conversation
|
||||
|
||||
|
||||
@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages():
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops == prompt_rules.get("stops")
|
||||
assert prompt_messages[0].content == real_prompt
|
||||
|
||||
|
||||
def test_get_prompt_dispatches_chat_and_completion():
|
||||
transform = SimplePromptTransform()
|
||||
model_config_chat = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_chat.mode = "chat"
|
||||
model_config_completion = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_completion.mode = "completion"
|
||||
prompt_entity = SimpleNamespace(simple_prompt_template="hello")
|
||||
|
||||
transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None))
|
||||
transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"]))
|
||||
|
||||
chat_messages, chat_stops = transform.get_prompt(
|
||||
app_mode=AppMode.CHAT,
|
||||
prompt_template_entity=prompt_entity,
|
||||
inputs={"n": 1},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config_chat,
|
||||
)
|
||||
assert chat_messages == ["chat-msg"]
|
||||
assert chat_stops is None
|
||||
|
||||
completion_messages, completion_stops = transform.get_prompt(
|
||||
app_mode=AppMode.CHAT,
|
||||
prompt_template_entity=prompt_entity,
|
||||
inputs={"n": 1},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config_completion,
|
||||
)
|
||||
assert completion_messages == ["completion-msg"]
|
||||
assert completion_stops == ["stop"]
|
||||
|
||||
|
||||
def test_get_prompt_str_and_rules_type_validation_errors():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config.provider = "openai"
|
||||
model_config.model = "gpt-4"
|
||||
valid_prompt_template = SimplePromptTransform().get_prompt_template(
|
||||
AppMode.CHAT, "openai", "gpt-4", "", False, False
|
||||
)["prompt_template"]
|
||||
|
||||
bad_custom_keys = {
|
||||
"prompt_template": valid_prompt_template,
|
||||
"custom_variable_keys": "not-list",
|
||||
"special_variable_keys": [],
|
||||
"prompt_rules": {},
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_custom_keys)
|
||||
with pytest.raises(TypeError, match="custom_variable_keys"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_special_keys = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": "not-list",
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_special_keys)
|
||||
with pytest.raises(TypeError, match="special_variable_keys"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_prompt_template = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": [],
|
||||
"prompt_template": 123,
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_prompt_template)
|
||||
with pytest.raises(TypeError, match="PromptTemplateParser"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_prompt_rules = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": [],
|
||||
"prompt_template": valid_prompt_template,
|
||||
"prompt_rules": "not-dict",
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_prompt_rules)
|
||||
with pytest.raises(TypeError, match="prompt_rules"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
|
||||
def test_chat_model_prompt_messages_uses_prompt_when_query_empty():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {}))
|
||||
transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text"))
|
||||
|
||||
prompt_messages, _ = transform._get_chat_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt="",
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert prompt_messages[0].content == "prompt-text"
|
||||
transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None)
|
||||
|
||||
|
||||
def test_completion_model_prompt_messages_empty_stops_becomes_none():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []}))
|
||||
|
||||
prompt_messages, stops = transform._get_completion_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt="",
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops is None
|
||||
|
||||
|
||||
def test_get_last_user_message_with_files_and_context_files():
|
||||
transform = SimplePromptTransform()
|
||||
file = SimpleNamespace()
|
||||
context_file = SimpleNamespace()
|
||||
|
||||
with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.side_effect = [
|
||||
ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"),
|
||||
ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"),
|
||||
]
|
||||
message = transform._get_last_user_message(
|
||||
prompt="hello",
|
||||
files=[file],
|
||||
context_files=[context_file],
|
||||
image_detail_config=None,
|
||||
)
|
||||
|
||||
assert isinstance(message.content, list)
|
||||
assert message.content[0].data == "https://example.com/a.jpg"
|
||||
assert message.content[1].data == "https://example.com/b.jpg"
|
||||
assert isinstance(message.content[2], TextPromptMessageContent)
|
||||
assert message.content[2].data == "hello"
|
||||
|
||||
|
||||
def test_prompt_file_name_branches():
|
||||
transform = SimplePromptTransform()
|
||||
|
||||
assert transform._prompt_file_name(AppMode.CHAT, "openai", "gpt-4") == "common_chat"
|
||||
assert transform._prompt_file_name(AppMode.COMPLETION, "openai", "gpt-4") == "common_completion"
|
||||
assert transform._prompt_file_name(AppMode.COMPLETION, "baichuan", "Baichuan2") == "baichuan_completion"
|
||||
assert transform._prompt_file_name(AppMode.CHAT, "huggingface_hub", "baichuan-13b") == "baichuan_chat"
|
||||
|
||||
|
||||
def test_advanced_prompt_templates_constants_are_importable():
|
||||
assert isinstance(CONTEXT, str)
|
||||
assert isinstance(BAICHUAN_CONTEXT, str)
|
||||
assert "completion_prompt_config" in CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
import dataclasses
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper import encrypter
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variables.segment_group import SegmentGroup
|
||||
from dify_graph.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
@ -23,6 +26,11 @@ from dify_graph.variables.segments import (
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from dify_graph.variables.utils import (
|
||||
dumps_with_segments,
|
||||
segment_orjson_default,
|
||||
to_selector,
|
||||
)
|
||||
from dify_graph.variables.variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
@ -379,3 +387,125 @@ class TestSegmentDumpAndLoad:
|
||||
assert get_segment_discriminator("not_a_dict") is None
|
||||
assert get_segment_discriminator(42) is None
|
||||
assert get_segment_discriminator(object) is None
|
||||
|
||||
|
||||
class TestSegmentAdditionalProperties:
|
||||
def test_base_segment_text_log_markdown_size_and_to_object(self):
|
||||
"""Ensure StringSegment exposes text, log, markdown, size and to_object."""
|
||||
segment = StringSegment(value="hello")
|
||||
|
||||
assert segment.text == "hello"
|
||||
assert segment.log == "hello"
|
||||
assert segment.markdown == "hello"
|
||||
assert segment.size > 0
|
||||
assert segment.to_object() == "hello"
|
||||
|
||||
def test_none_segment_empty_outputs(self):
|
||||
"""Ensure NoneSegment renders empty text, log and markdown."""
|
||||
segment = NoneSegment()
|
||||
|
||||
assert segment.text == ""
|
||||
assert segment.log == ""
|
||||
assert segment.markdown == ""
|
||||
|
||||
def test_object_segment_json_outputs(self):
|
||||
"""Ensure ObjectSegment renders JSON output for text, log and markdown."""
|
||||
segment = ObjectSegment(value={"key": "值", "n": 1})
|
||||
|
||||
assert segment.text == '{"key": "值", "n": 1}'
|
||||
assert segment.log == '{\n "key": "值",\n "n": 1\n}'
|
||||
assert segment.markdown == '{\n "key": "值",\n "n": 1\n}'
|
||||
|
||||
def test_array_segment_text_and_markdown(self):
|
||||
"""Ensure ArrayAnySegment handles empty/non-empty text and markdown rendering."""
|
||||
empty_segment = ArrayAnySegment(value=[])
|
||||
non_empty_segment = ArrayAnySegment(value=[1, "two"])
|
||||
|
||||
assert empty_segment.text == ""
|
||||
assert non_empty_segment.text == "[1, 'two']"
|
||||
assert non_empty_segment.markdown == "- 1\n- two"
|
||||
|
||||
def test_file_segment_properties(self):
|
||||
"""Ensure FileSegment markdown, text and log fields match expected behavior."""
|
||||
file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="doc.txt")
|
||||
segment = FileSegment(value=file)
|
||||
|
||||
assert segment.markdown == "[doc.txt](https://example.com/file.txt)"
|
||||
assert segment.log == ""
|
||||
assert segment.text == ""
|
||||
|
||||
def test_array_string_segment_text_branches(self):
|
||||
"""Ensure ArrayStringSegment text handling for empty and non-empty values."""
|
||||
empty_segment = ArrayStringSegment(value=[])
|
||||
non_empty_segment = ArrayStringSegment(value=["hello", "世界"])
|
||||
|
||||
assert empty_segment.text == ""
|
||||
assert non_empty_segment.text == '["hello", "世界"]'
|
||||
|
||||
def test_array_file_segment_markdown_and_empty_text_log(self):
|
||||
"""Ensure ArrayFileSegment markdown renders links and text/log stay empty."""
|
||||
file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt")
|
||||
file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt")
|
||||
segment = ArrayFileSegment(value=[file1, file2])
|
||||
|
||||
assert segment.markdown == "[a.txt](https://example.com/file.txt)\n[b.txt](https://example.com/file.txt)"
|
||||
assert segment.log == ""
|
||||
assert segment.text == ""
|
||||
|
||||
|
||||
class TestSegmentGroupAdditional:
|
||||
def test_segment_group_markdown_and_to_object(self):
|
||||
group = SegmentGroup(value=[StringSegment(value="A"), NoneSegment(), StringSegment(value="B")])
|
||||
|
||||
assert group.markdown == "AB"
|
||||
assert group.to_object() == ["A", None, "B"]
|
||||
|
||||
|
||||
class TestSegmentUtils:
|
||||
def test_to_selector_without_paths(self):
|
||||
assert to_selector("node-1", "output") == ["node-1", "output"]
|
||||
|
||||
def test_to_selector_with_paths(self):
|
||||
assert to_selector("node-1", "output", ("a", "b")) == ["node-1", "output", "a", "b"]
|
||||
|
||||
def test_array_file_segment_serialization(self):
|
||||
file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt")
|
||||
file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt")
|
||||
|
||||
result = segment_orjson_default(ArrayFileSegment(value=[file1, file2]))
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["filename"] == "a.txt"
|
||||
assert result[1]["filename"] == "b.txt"
|
||||
|
||||
def test_file_segment_serialization(self):
|
||||
file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="single.txt")
|
||||
|
||||
result = segment_orjson_default(FileSegment(value=file))
|
||||
|
||||
assert result["filename"] == "single.txt"
|
||||
assert result["remote_url"] == "https://example.com/file.txt"
|
||||
|
||||
def test_segment_group_and_segment_serialization(self):
|
||||
group = SegmentGroup(value=[StringSegment(value="a"), StringSegment(value="b")])
|
||||
|
||||
assert segment_orjson_default(group) == ["a", "b"]
|
||||
assert segment_orjson_default(StringSegment(value="value")) == "value"
|
||||
|
||||
def test_segment_orjson_default_unsupported_type(self):
|
||||
with pytest.raises(TypeError, match="not JSON serializable"):
|
||||
segment_orjson_default(object())
|
||||
|
||||
def test_dumps_with_segments(self):
|
||||
data = {
|
||||
"segment": StringSegment(value="hello"),
|
||||
"group": SegmentGroup(value=[StringSegment(value="x"), StringSegment(value="y")]),
|
||||
1: "numeric-key",
|
||||
}
|
||||
|
||||
dumped = dumps_with_segments(data)
|
||||
loaded = orjson.loads(dumped)
|
||||
|
||||
assert loaded["segment"] == "hello"
|
||||
assert loaded["group"] == ["x", "y"]
|
||||
assert loaded["1"] == "numeric-key"
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.variables.segment_group import SegmentGroup
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from dify_graph.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
@ -69,22 +71,36 @@ class TestSegmentTypeIsValidArrayValidation:
|
||||
"""
|
||||
|
||||
def test_array_validation_all_success(self):
|
||||
# Arrange
|
||||
value = ["hello", "world", "foo"]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
# Act
|
||||
is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
# Assert
|
||||
assert is_valid
|
||||
|
||||
def test_array_validation_all_fail(self):
|
||||
# Arrange
|
||||
value = ["hello", 123, "world"]
|
||||
# Should return False, since 123 is not a string
|
||||
assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
# Act
|
||||
is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
# Assert
|
||||
assert not is_valid
|
||||
|
||||
def test_array_validation_first(self):
|
||||
# Arrange
|
||||
value = ["hello", 123, None]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
|
||||
# Act
|
||||
is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
|
||||
# Assert
|
||||
assert is_valid
|
||||
|
||||
def test_array_validation_none(self):
|
||||
# Arrange
|
||||
value = [1, 2, 3]
|
||||
# validation is None, skip
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
||||
# Act
|
||||
is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
||||
# Assert
|
||||
assert is_valid
|
||||
|
||||
|
||||
class TestSegmentTypeGetZeroValue:
|
||||
@ -163,3 +179,62 @@ class TestSegmentTypeGetZeroValue:
|
||||
for seg_type in unsupported_types:
|
||||
with pytest.raises(ValueError, match="unsupported variable type"):
|
||||
SegmentType.get_zero_value(seg_type)
|
||||
|
||||
|
||||
class TestSegmentTypeInferSegmentType:
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
([], SegmentType.ARRAY_NUMBER),
|
||||
([1, 2, 3], SegmentType.ARRAY_NUMBER),
|
||||
([1, 2.5], SegmentType.ARRAY_NUMBER),
|
||||
(["a", "b"], SegmentType.ARRAY_STRING),
|
||||
([{"k": "v"}], SegmentType.ARRAY_OBJECT),
|
||||
([None], SegmentType.ARRAY_ANY),
|
||||
([True, False], SegmentType.ARRAY_BOOLEAN),
|
||||
([[1], [2]], SegmentType.ARRAY_ANY),
|
||||
([1, "a"], SegmentType.ARRAY_ANY),
|
||||
(None, SegmentType.NONE),
|
||||
(True, SegmentType.BOOLEAN),
|
||||
(1, SegmentType.INTEGER),
|
||||
(1.2, SegmentType.FLOAT),
|
||||
("abc", SegmentType.STRING),
|
||||
({"k": "v"}, SegmentType.OBJECT),
|
||||
],
|
||||
)
|
||||
def test_infer_segment_type_supported_values(self, value, expected):
|
||||
assert SegmentType.infer_segment_type(value) == expected
|
||||
|
||||
|
||||
class TestSegmentTypeAdditionalMethods:
|
||||
def test_cast_value_for_bool_number_and_array_number(self):
|
||||
assert SegmentType.cast_value(True, SegmentType.INTEGER) == 1
|
||||
assert SegmentType.cast_value(False, SegmentType.NUMBER) == 0
|
||||
assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0]
|
||||
|
||||
mixed = [True, 1]
|
||||
assert SegmentType.cast_value(mixed, SegmentType.ARRAY_NUMBER) is mixed
|
||||
assert SegmentType.cast_value("x", SegmentType.STRING) == "x"
|
||||
|
||||
def test_exposed_type_and_element_type(self):
|
||||
assert SegmentType.INTEGER.exposed_type() == SegmentType.NUMBER
|
||||
assert SegmentType.FLOAT.exposed_type() == SegmentType.NUMBER
|
||||
assert SegmentType.STRING.exposed_type() == SegmentType.STRING
|
||||
|
||||
assert SegmentType.ARRAY_STRING.element_type() == SegmentType.STRING
|
||||
assert SegmentType.ARRAY_ANY.element_type() is None
|
||||
|
||||
with pytest.raises(ValueError, match="element_type is only supported by array type"):
|
||||
SegmentType.STRING.element_type()
|
||||
|
||||
def test_group_validation_for_segment_group_and_list(self):
|
||||
valid_group = SegmentGroup(value=[StringSegment(value="a")])
|
||||
assert SegmentType.GROUP.is_valid(valid_group) is True
|
||||
assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True
|
||||
assert SegmentType.GROUP.is_valid(["not-segment"]) is False
|
||||
|
||||
def test_unreachable_assertion_branch(self, monkeypatch):
|
||||
monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False)
|
||||
|
||||
with pytest.raises(AssertionError, match="unreachable"):
|
||||
SegmentType.ARRAY_STRING.is_valid(["a"])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user