From 07e19c074878c3e1f7e76c4c4dff9ea49b249921 Mon Sep 17 00:00:00 2001 From: Rajat Agarwal Date: Thu, 12 Mar 2026 08:59:02 +0530 Subject: [PATCH] test: unit test cases for core.variables, core.plugin, core.prompt module (#32637) --- api/core/plugin/entities/parameters.py | 2 +- .../unit_tests/core/plugin/impl/__init__.py | 0 .../core/plugin/impl/test_agent_client.py | 91 ++++ .../core/plugin/impl/test_asset_manager.py | 45 ++ .../core/plugin/impl/test_base_client_impl.py | 137 +++++ .../plugin/impl/test_datasource_manager.py | 234 +++++++++ .../core/plugin/impl/test_debugging_client.py | 21 + .../plugin/impl/test_endpoint_client_impl.py | 71 +++ .../core/plugin/impl/test_exc_impl.py | 41 ++ .../core/plugin/impl/test_model_client.py | 490 ++++++++++++++++++ .../core/plugin/impl/test_oauth_handler.py | 147 ++++++ .../core/plugin/impl/test_tool_manager.py | 121 +++++ .../core/plugin/impl/test_trigger_client.py | 226 ++++++++ .../plugin/test_backwards_invocation_app.py | 391 ++++++++++++-- .../core/plugin/test_plugin_entities.py | 347 +++++++++++++ .../core/plugin/utils/test_chunk_merger.py | 91 +++- .../core/plugin/utils/test_http_parser.py | 48 ++ .../prompt/test_advanced_prompt_transform.py | 328 ++++++++++++ .../prompt/test_extract_thread_messages.py | 45 +- .../core/prompt/test_prompt_message.py | 84 +++ .../core/prompt/test_prompt_transform.py | 247 +++++++-- .../prompt/test_simple_prompt_transform.py | 199 ++++++- .../unit_tests/core/variables/test_segment.py | 130 +++++ .../core/variables/test_segment_type.py | 87 +++- 24 files changed, 3526 insertions(+), 97 deletions(-) create mode 100644 api/tests/unit_tests/core/plugin/impl/__init__.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_agent_client.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_asset_manager.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_debugging_client.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_exc_impl.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_model_client.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_tool_manager.py create mode 100644 api/tests/unit_tests/core/plugin/impl/test_trigger_client.py create mode 100644 api/tests/unit_tests/core/plugin/test_plugin_entities.py diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index bfa662b9f6..ce5813a294 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /): except ValueError: raise except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") + raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.") def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): diff --git a/api/tests/unit_tests/core/plugin/impl/__init__.py b/api/tests/unit_tests/core/plugin/impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/plugin/impl/test_agent_client.py b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py new file mode 100644 index 0000000000..1537ffacf5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace + +from core.plugin.entities.request import PluginInvokeContext +from core.plugin.impl.agent import PluginAgentClient + + +def _agent_provider(name: str = "agent") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + strategies=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginAgentClient: + def test_fetch_agent_strategy_providers(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("remote") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "strategies": [{"identity": {"provider": "old"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["strategies"][0]["identity"]["provider"] == "remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.strategies[0].identity.provider == "org/plugin/remote" + + def test_fetch_agent_strategy_provider(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("provider") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + assert transformer({"data": None}) == {"data": None} + payload = {"data": {"declaration": {"strategies": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["strategies"][0]["identity"]["provider"] == "provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.strategies[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks_and_passes_context(self, mocker): + client = PluginAgentClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["raw"]) + ) + merge_mock = mocker.patch("core.plugin.impl.agent.merge_blob_chunks", return_value=["merged"]) + context = PluginInvokeContext() + + result = client.invoke( + tenant_id="tenant-1", + user_id="user-1", + agent_provider="org/plugin/provider", + agent_strategy="router", + agent_params={"k": "v"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + context=context, + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + payload = stream_mock.call_args.kwargs["data"] + assert payload["data"]["agent_strategy_provider"] == "provider" + assert payload["context"] == context.model_dump() + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" diff --git a/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py new file mode 100644 index 0000000000..5f564062d5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +import pytest + +from core.plugin.impl.asset import PluginAssetManager + + +class TestPluginAssetManager: + def test_fetch_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"asset-bytes") + request_mock = mocker.patch.object(manager, "_request", return_value=response) + + result = manager.fetch_asset("tenant-1", "asset-1") + + assert result == b"asset-bytes" + request_mock.assert_called_once_with(method="GET", path="plugin/tenant-1/asset/asset-1") + + def test_fetch_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset asset-1"): + manager.fetch_asset("tenant-1", "asset-1") + + def test_extract_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"file-content") + request_mock = mocker.patch.object(manager, "_request", return_value=response) + + result = manager.extract_asset("tenant-1", "org/plugin:1", "README.md") + + assert result == b"file-content" + request_mock.assert_called_once_with( + method="GET", + path="plugin/tenant-1/extract-asset/", + params={"plugin_unique_identifier": "org/plugin:1", "file_path": "README.md"}, + ) + + def test_extract_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset org/plugin:1, 404"): + manager.extract_asset("tenant-1", "org/plugin:1", "README.md") diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py new file mode 100644 index 0000000000..c216906d68 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py @@ -0,0 +1,137 @@ +import json + +import pytest + +from core.plugin.endpoint.exc import EndpointSetupFailedError +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.base import BasePluginClient +from core.trigger.errors import ( + EventIgnoreError, + TriggerInvokeError, + TriggerPluginInvokeError, + TriggerProviderCredentialValidationError, +) + + +class _ResponseStub: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _StreamContext: + def __init__(self, lines): + self._lines = lines + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def iter_lines(self): + return self._lines + + +class TestBasePluginClientImpl: + def test_inject_trace_headers(self, mocker): + client = BasePluginClient() + mocker.patch("core.plugin.impl.base.dify_config.ENABLE_OTEL", True) + trace_header = "00-abc-xyz-01" + mocker.patch("core.helper.trace_id_helper.generate_traceparent_header", return_value=trace_header) + + headers = {} + client._inject_trace_headers(headers) + + assert headers["traceparent"] == trace_header + + headers_with_existing = {"TraceParent": "exists"} + client._inject_trace_headers(headers_with_existing) + assert headers_with_existing["TraceParent"] == "exists" + + def test_stream_request_handles_data_lines_and_dict_payload(self, mocker): + client = BasePluginClient() + stream_mock = mocker.patch( + "core.plugin.impl.base.httpx.stream", + return_value=_StreamContext([b"", b"data: hello", "world"]), + ) + + result = list(client._stream_request("POST", "plugin/tenant/stream", data={"k": "v"})) + + assert result == ["hello", "world"] + assert stream_mock.call_args.kwargs["data"] == {"k": "v"} + + def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", side_effect=RuntimeError("boom")) + + with pytest.raises(ValueError, match="Failed to request plugin daemon"): + client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool) + + def test_request_with_plugin_daemon_response_applies_transformer(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", return_value=_ResponseStub({"code": 0, "message": "", "data": True})) + + transformed = {} + + def transformer(payload): + transformed.update(payload) + return payload + + result = client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool, transformer=transformer) + + assert result is True + assert transformed == {"code": 0, "message": "", "data": True} + + def test_request_with_plugin_daemon_response_stream_malformed_json_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"error":"bad-line"}'])) + + with pytest.raises(ValueError, match="bad-line"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_inner_error(self, mocker): + client = BasePluginClient() + mocker.patch.object( + client, "_stream_request", return_value=iter(['{"code":-500,"message":"not-json","data":null}']) + ) + + with pytest.raises(PluginDaemonInnerError) as exc_info: + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + assert exc_info.value.message == "not-json" + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":-1,"message":"err","data":null}'])) + + with pytest.raises(ValueError, match="plugin daemon: err, code: -1"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_empty_data_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":0,"message":"","data":null}'])) + + with pytest.raises(ValueError, match="got empty data"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + @pytest.mark.parametrize( + ("error_type", "expected"), + [ + (EndpointSetupFailedError.__name__, EndpointSetupFailedError), + (TriggerProviderCredentialValidationError.__name__, TriggerProviderCredentialValidationError), + (TriggerPluginInvokeError.__name__, TriggerPluginInvokeError), + (TriggerInvokeError.__name__, TriggerInvokeError), + (EventIgnoreError.__name__, EventIgnoreError), + ], + ) + def test_handle_plugin_daemon_error_trigger_branches(self, error_type, expected): + client = BasePluginClient() + message = json.dumps({"error_type": error_type, "message": "m"}) + + with pytest.raises(expected): + client._handle_plugin_daemon_error("PluginInvokeError", message) diff --git a/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py new file mode 100644 index 0000000000..4c5987d759 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace + +from core.datasource.entities.datasource_entities import ( + GetOnlineDocumentPageContentRequest, + OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +def _datasource_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + datasources=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginDatasourceManager: + def test_fetch_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 2 + assert result[0].plugin_id == "langgenius/file" + assert result[1].declaration.identity.name == "org/plugin/remote" + assert result[1].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_installed_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformer(payload) + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_installed_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_datasource_provider_local_and_remote(self, mocker): + manager = PluginDatasourceManager() + + local = manager.fetch_datasource_provider("tenant-1", "langgenius/file/file") + assert local.plugin_id == "langgenius/file" + + remote = _datasource_provider("provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": { + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}] + } + } + } + transformed = transformer(payload) + assert transformed["data"]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return remote + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.datasources[0].identity.provider == "org/plugin/provider" + + def test_get_website_crawl_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["crawl"]) + + assert list( + manager.get_website_crawl( + "tenant-1", + "user-1", + "org/plugin/provider", + "crawl", + {"k": "v"}, + {"url": "https://example.com"}, + "website", + ) + ) == ["crawl"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_pages_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["pages"]) + + assert list( + manager.get_online_document_pages( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + {"workspace": "w1"}, + "online_document", + ) + ) == ["pages"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_page_content_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["content"]) + + assert list( + manager.get_online_document_page_content( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + GetOnlineDocumentPageContentRequest(workspace_id="w", page_id="p", type="doc"), + "online_document", + ) + ) == ["content"] + + assert stream_mock.call_count == 1 + + def test_online_drive_browse_files_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["browse"]) + + assert list( + manager.online_drive_browse_files( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveBrowseFilesRequest(prefix="/"), + "online_drive", + ) + ) == ["browse"] + + assert stream_mock.call_count == 1 + + def test_online_drive_download_file_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["download"]) + + assert list( + manager.online_drive_download_file( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveDownloadFileRequest(id="file-1"), + "online_drive", + ) + ) == ["download"] + + assert stream_mock.call_count == 1 + + def test_validate_provider_credentials_returns_true_when_stream_yields_result(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + + assert manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is True + + def test_validate_provider_credentials_returns_false_when_stream_empty(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([]) + + assert ( + manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is False + ) + + def test_local_file_provider_template(self): + manager = PluginDatasourceManager() + + payload = manager._get_local_file_datasource_provider() + + assert payload["plugin_id"] == "langgenius/file" + assert payload["provider"] == "file" + assert payload["declaration"]["provider_type"] == "local_file" diff --git a/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py new file mode 100644 index 0000000000..c80785aee0 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py @@ -0,0 +1,21 @@ +from types import SimpleNamespace + +from core.plugin.impl.debugging import PluginDebuggingClient + + +class TestPluginDebuggingClient: + def test_get_debugging_key(self, mocker): + client = PluginDebuggingClient() + request_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response", + return_value=SimpleNamespace(key="debug-key"), + ) + + result = client.get_debugging_key("tenant-1") + + assert result == "debug-key" + request_mock.assert_called_once() + args = request_mock.call_args.args + assert args[0] == "POST" + assert args[1] == "plugin/tenant-1/debugging/key" diff --git a/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py new file mode 100644 index 0000000000..4cf657a050 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py @@ -0,0 +1,71 @@ +import pytest + +from core.plugin.impl.endpoint import PluginEndpointClient +from core.plugin.impl.exc import PluginDaemonInternalServerError + + +class TestPluginEndpointClientImpl: + def test_create_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.create_endpoint("tenant-1", "user-1", "org/plugin:1", "endpoint-a", {"k": "v"}) + + assert result is True + assert request_mock.call_count == 1 + args = request_mock.call_args.args + kwargs = request_mock.call_args.kwargs + assert args[:3] == ("POST", "plugin/tenant-1/endpoint/setup", bool) + assert kwargs["data"]["plugin_unique_identifier"] == "org/plugin:1" + + def test_list_endpoints(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints("tenant-1", "user-1", 2, 20) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list" + assert request_mock.call_args.kwargs["params"] == {"page": 2, "page_size": 20} + + def test_list_endpoints_for_single_plugin(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints_for_single_plugin("tenant-1", "user-1", "org/plugin", 1, 10) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list/plugin" + assert request_mock.call_args.kwargs["params"] == {"plugin_id": "org/plugin", "page": 1, "page_size": 10} + + def test_update_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.update_endpoint("tenant-1", "user-1", "endpoint-1", "renamed", {"x": 1}) + + assert result is True + assert request_mock.call_args.args[:3] == ("POST", "plugin/tenant-1/endpoint/update", bool) + + def test_enable_and_disable_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + assert client.enable_endpoint("tenant-1", "user-1", "endpoint-1") is True + assert client.disable_endpoint("tenant-1", "user-1", "endpoint-1") is True + + calls = request_mock.call_args_list + assert calls[0].args[1] == "plugin/tenant-1/endpoint/enable" + assert calls[1].args[1] == "plugin/tenant-1/endpoint/disable" + + def test_delete_endpoint_idempotent_and_re_raise(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response") + + request_mock.side_effect = PluginDaemonInternalServerError("record not found") + assert client.delete_endpoint("tenant-1", "user-1", "endpoint-1") is True + + request_mock.side_effect = PluginDaemonInternalServerError("permission denied") + with pytest.raises(PluginDaemonInternalServerError) as exc_info: + client.delete_endpoint("tenant-1", "user-1", "endpoint-1") + assert "permission denied" in exc_info.value.description diff --git a/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py new file mode 100644 index 0000000000..8c6f1c6b7f --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py @@ -0,0 +1,41 @@ +import json + +from core.plugin.impl import exc as exc_module +from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError + + +class TestPluginImplExceptions: + def test_plugin_daemon_error_str_contains_request_id(self, mocker): + mocker.patch("core.plugin.impl.exc.get_request_id", return_value="req-123") + error = PluginDaemonError("bad") + + assert str(error) == "req_id: req-123 PluginDaemonError: bad" + + def test_plugin_invoke_error_with_json_payload(self): + err = PluginInvokeError(json.dumps({"error_type": "RateLimit", "message": "too many"})) + + assert err.get_error_type() == "RateLimit" + assert err.get_error_message() == "too many" + friendly = err.to_user_friendly_error("test-plugin") + assert "test-plugin" in friendly + assert "RateLimit" in friendly + assert "too many" in friendly + + def test_plugin_invoke_error_invalid_json_and_fallback(self, mocker): + err = PluginInvokeError("plain text") + + assert err._get_error_object() == {} + assert err.get_error_type() == "unknown" + assert err.get_error_message() == "unknown" + + mocker.patch.object(PluginInvokeError, "_get_error_object", side_effect=RuntimeError("boom")) + err2 = PluginInvokeError("plain text") + assert err2.get_error_message() == "plain text" + + def test_plugin_invoke_error_get_error_object_handles_adapter_exception(self, mocker): + adapter = mocker.patch.object(exc_module, "TypeAdapter") + adapter.return_value.validate_json.side_effect = RuntimeError("invalid") + + err = PluginInvokeError("not-json") + + assert err._get_error_object() == {} diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_client.py b/api/tests/unit_tests/core/plugin/impl/test_model_client.py new file mode 100644 index 0000000000..bcbebbb38b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_client.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.model import PluginModelClient + + +class TestPluginModelClient: + def test_fetch_model_providers(self, mocker): + client = PluginModelClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["provider-a"]) + + result = client.fetch_model_providers("tenant-1") + + assert result == ["provider-a"] + assert request_mock.call_args.args[:2] == ( + "GET", + "plugin/tenant-1/management/models", + ) + assert request_mock.call_args.kwargs["params"] == {"page": 1, "page_size": 256} + + def test_get_model_schema(self, mocker): + client = PluginModelClient() + schema = SimpleNamespace(name="schema") + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(model_schema=schema)]), + ) + + result = client.get_model_schema( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={"api_key": "key"}, + ) + + assert result is schema + assert stream_mock.call_args.args[:2] == ("POST", "plugin/tenant-1/dispatch/model/schema") + + def test_get_model_schema_empty_stream_returns_none(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + result = client.get_model_schema("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + + assert result is None + + def test_validate_provider_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"api_key": "new"})]), + ) + credentials = {"api_key": "old"} + + result = client.validate_provider_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + credentials=credentials, + ) + + assert result is True + assert credentials["api_key"] == "new" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_provider_credentials", + ) + + def test_validate_provider_credentials_without_dict_update(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=False, credentials="not-a-dict")]), + ) + credentials = {"api_key": "same"} + + result = client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", credentials) + + assert result is False + assert credentials == {"api_key": "same"} + + def test_validate_provider_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", {}) is False + + def test_validate_model_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"token": "rotated"})]), + ) + credentials = {"token": "old"} + + result = client.validate_model_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials=credentials, + ) + + assert result is True + assert credentials["token"] == "rotated" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_model_credentials", + ) + + def test_validate_model_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.validate_model_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + is False + ) + + def test_invoke_llm(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"]) + ) + + result = list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={"api_key": "key"}, + prompt_messages=[], + model_parameters={"temperature": 0.1}, + tools=[], + stop=["STOP"], + stream=False, + ) + ) + + assert result == ["chunk-1"] + call_kwargs = stream_mock.call_args.kwargs + assert call_kwargs["path"] == "plugin/tenant-1/dispatch/llm/invoke" + assert call_kwargs["data"]["data"]["stream"] is False + assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1} + + def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-500, message="invoke failed") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="invoke failed-500"): + list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={}, + prompt_messages=[], + ) + ) + + def test_get_llm_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=42)]), + ) + + result = client.get_llm_num_tokens( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={}, + prompt_messages=[], + tools=[], + ) + + assert result == 42 + + def test_get_llm_num_tokens_empty_returns_zero(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_llm_num_tokens("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}, []) + == 0 + ) + + def test_invoke_text_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.1, 0.2]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_text_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + texts=["hello"], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_text_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke text embedding"): + client.invoke_text_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["hello"], "x" + ) + + def test_invoke_multimodal_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.3, 0.4]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_multimodal_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + documents=[{"type": "image", "value": "abc"}], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_multimodal_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke file embedding"): + client.invoke_multimodal_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, [{"type": "image"}], "x" + ) + + def test_get_text_embedding_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=[1, 2, 3])]), + ) + + assert client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) == [ + 1, + 2, + 3, + ] + + def test_get_text_embedding_num_tokens_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) + == [] + ) + + def test_invoke_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.9]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query="q", + docs=["doc-1"], + score_threshold=0.2, + top_n=5, + ) + + assert result is rerank_result + + def test_invoke_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke rerank"): + client.invoke_rerank("tenant-1", "user-1", "org/plugin:1", "provider-a", "rerank-a", {}, "q", ["doc-1"]) + + def test_invoke_multimodal_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.8]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_multimodal_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query={"type": "text", "value": "q"}, + docs=[{"type": "image", "value": "doc"}], + score_threshold=0.1, + top_n=3, + ) + + assert result is rerank_result + + def test_invoke_multimodal_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke multimodal rerank"): + client.invoke_multimodal_rerank( + "tenant-1", + "user-1", + "org/plugin:1", + "provider-a", + "rerank-a", + {}, + {"type": "text"}, + [{"type": "image"}], + ) + + def test_invoke_tts(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="68656c6c6f"), SimpleNamespace(result="21")]), + ) + + result = list( + client.invoke_tts( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + content_text="hello", + voice="alloy", + ) + ) + + assert result == [b"hello", b"!"] + + def test_invoke_tts_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-400, message="tts error") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="tts error-400"): + list(client.invoke_tts("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}, "hello", "alloy")) + + def test_get_tts_model_voices(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter( + [ + SimpleNamespace( + voices=[ + SimpleNamespace(name="Alloy", value="alloy"), + SimpleNamespace(name="Echo", value="echo"), + ] + ) + ] + ), + ) + + result = client.get_tts_model_voices( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + language="en", + ) + + assert result == [{"name": "Alloy", "value": "alloy"}, {"name": "Echo", "value": "echo"}] + + def test_get_tts_model_voices_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.get_tts_model_voices("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}) == [] + + def test_invoke_speech_to_text(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="transcribed text")]), + ) + + result = client.invoke_speech_to_text( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="stt-a", + credentials={}, + file=io.BytesIO(b"abc"), + ) + + assert result == "transcribed text" + assert stream_mock.call_args.kwargs["data"]["data"]["file"] == "616263" + + def test_invoke_speech_to_text_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke speech to text"): + client.invoke_speech_to_text( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "stt-a", {}, io.BytesIO(b"abc") + ) + + def test_invoke_moderation(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True)]), + ) + + result = client.invoke_moderation( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="moderation-a", + credentials={}, + text="safe text", + ) + + assert result is True + assert stream_mock.call_args.kwargs["path"] == "plugin/tenant-1/dispatch/moderation/invoke" + + def test_invoke_moderation_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke moderation"): + client.invoke_moderation("tenant-1", "user-1", "org/plugin:1", "provider-a", "moderation-a", {}, "unsafe") diff --git a/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py new file mode 100644 index 0000000000..6fb4c99432 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py @@ -0,0 +1,147 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.impl.oauth import OAuthHandler + + +def _build_request(body: bytes = b"payload") -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/oauth/callback", + "QUERY_STRING": "code=123", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "HTTP_HOST": "localhost", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_X_TEST": "yes", + } + return Request(environ) + + +class TestOAuthHandler: + def test_get_authorization_url(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(authorization_url="https://auth.example.com")]), + ) + + response = handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + ) + + assert response.authorization_url == "https://auth.example.com" + assert stream_mock.call_count == 1 + + def test_get_authorization_url_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error getting authorization URL"): + handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + ) + + def test_get_credentials(self, mocker): + handler = OAuthHandler() + captured_data = {} + + def fake_stream(*args, **kwargs): + captured_data.update(kwargs["data"]) + return iter([SimpleNamespace(credentials={"token": "abc"}, metadata={}, expires_at=1)]) + + stream_mock = mocker.patch.object( + handler, "_request_with_plugin_daemon_response_stream", side_effect=fake_stream + ) + + response = handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + request=_build_request(), + ) + + assert response.credentials == {"token": "abc"} + assert "raw_http_request" in captured_data["data"] + assert stream_mock.call_count == 1 + + def test_get_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error getting credentials"): + handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + request=_build_request(), + ) + + def test_refresh_credentials(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(credentials={"token": "new"}, metadata={}, expires_at=1)]), + ) + + response = handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + credentials={"refresh_token": "r"}, + ) + + assert response.credentials == {"token": "new"} + assert stream_mock.call_count == 1 + + def test_refresh_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error refreshing credentials"): + handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + credentials={}, + ) + + def test_convert_request_to_raw_data(self): + handler = OAuthHandler() + request = _build_request(b"body-data") + + raw = handler._convert_request_to_raw_data(request) + + assert raw.startswith(b"POST /oauth/callback?code=123 HTTP/1.1\r\n") + assert b"X-Test: yes\r\n" in raw + assert raw.endswith(b"body-data") diff --git a/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py new file mode 100644 index 0000000000..80cf46f9bb --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py @@ -0,0 +1,121 @@ +from types import SimpleNamespace + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.tool import PluginToolManager + + +def _tool_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + tools=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginToolManager: + def test_fetch_tool_providers(self, mocker): + manager = PluginToolManager() + provider = _tool_provider("remote") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.tools[0].identity.provider == "org/plugin/remote" + + def test_fetch_tool_provider(self, mocker): + manager = PluginToolManager() + provider = _tool_provider("provider") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": {"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]} + } + } + transformed = transformer(payload) + assert transformed["data"]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return provider + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.tools[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object( + manager, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk"]) + ) + merge_mock = mocker.patch("core.plugin.impl.tool.merge_blob_chunks", return_value=["merged"]) + + result = manager.invoke( + tenant_id="tenant-1", + user_id="user-1", + tool_provider="org/plugin/provider", + tool_name="search", + credentials={"api_key": "k"}, + credential_type=CredentialType.API_KEY, + tool_parameters={"q": "python"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" + + def test_validate_credentials_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + def test_get_runtime_parameters_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(parameters=[{"name": "p"}])]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [{"name": "p"}] + + stream_mock.return_value = iter([]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [] diff --git a/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py new file mode 100644 index 0000000000..76da51c2c8 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py @@ -0,0 +1,226 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.trigger import PluginTriggerClient +from core.trigger.entities.entities import Subscription +from models.provider_ids import TriggerProviderID + + +def _request() -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/events", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(b"payload"), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": "7", + "HTTP_HOST": "localhost", + } + return Request(environ) + + +def _subscription() -> Subscription: + return Subscription(expires_at=123, endpoint="https://example.com/hook", parameters={"a": 1}, properties={"p": 1}) + + +def _trigger_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + events=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +def _subscription_call_kwargs(method_name: str) -> dict: + if method_name == "subscribe": + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + "endpoint": "https://example.com/hook", + "parameters": {"k": "v"}, + } + + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "subscription": _subscription(), + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + } + + +class TestPluginTriggerClient: + def test_fetch_trigger_providers(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("remote") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "plugin_id": "org/plugin", + "provider": "remote", + "declaration": {"events": [{"identity": {"provider": "old"}}]}, + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.events[0].identity.provider == "org/plugin/remote" + + def test_fetch_trigger_provider(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("provider") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = {"data": {"declaration": {"events": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_provider("tenant-1", TriggerProviderID("org/plugin/provider")) + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.events[0].identity.provider == "org/plugin/provider" + + def test_invoke_trigger_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(variables={"ok": True}, cancelled=False)]), + ) + + result = client.invoke_trigger_event( + tenant_id="tenant-1", + user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + assert result.variables == {"ok": True} + assert stream_mock.call_count == 1 + + def test_invoke_trigger_event_no_response_raises(self, mocker): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="No response received from plugin daemon for invoke trigger"): + client.invoke_trigger_event( + tenant_id="tenant-1", + user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + def test_validate_provider_credentials(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + with pytest.raises( + ValueError, match="No response received from plugin daemon for validate provider credentials" + ): + client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) + + def test_dispatch_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(user_id="u", events=["e"])]), + ) + + result = client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result.user_id == "u" + assert stream_mock.call_count == 1 + + stream_mock.return_value = iter([]) + with pytest.raises(ValueError, match="No response received from plugin daemon for dispatch event"): + client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + @pytest.mark.parametrize("method_name", ["subscribe", "unsubscribe", "refresh"]) + def test_subscription_operations_success(self, mocker, method_name): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(subscription={"id": "sub"})]), + ) + + method = getattr(client, method_name) + result = method(**_subscription_call_kwargs(method_name)) + + assert result.subscription == {"id": "sub"} + assert stream_mock.call_count == 1 + + @pytest.mark.parametrize( + ("method_name", "expected"), + [ + ("subscribe", "No response received from plugin daemon for subscribe"), + ("unsubscribe", "No response received from plugin daemon for unsubscribe"), + ("refresh", "No response received from plugin daemon for refresh"), + ], + ) + def test_subscription_operations_no_response(self, mocker, method_name, expected): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + method = getattr(client, method_name) + + with pytest.raises(ValueError, match=expected): + method(**_subscription_call_kwargs(method_name)) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index a380149554..c2778f082b 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -1,72 +1,359 @@ +import json from types import SimpleNamespace from unittest.mock import MagicMock +import pytest +from pydantic import BaseModel + from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from models.model import AppMode -def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" +class _Chunk(BaseModel): + value: int - app = MagicMock() - app.mode = AppMode.ADVANCED_CHAT - app.workflow = workflow - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), +class TestBaseBackwardsInvocation: + def test_convert_to_event_stream_with_generator_and_error(self): + def _stream(): + yield _Chunk(value=1) + yield {"x": 2} + yield "ignored" + raise RuntimeError("boom") + + chunks = list(BaseBackwardsInvocation.convert_to_event_stream(_stream())) + + assert len(chunks) == 3 + first = json.loads(chunks[0].decode()) + second = json.loads(chunks[1].decode()) + error = json.loads(chunks[2].decode()) + assert first["data"]["value"] == 1 + assert second["data"]["x"] == 2 + assert error["error"] == "boom" + + def test_convert_to_event_stream_with_non_generator(self): + chunks = list(BaseBackwardsInvocation.convert_to_event_stream({"ok": True})) + payload = json.loads(chunks[0].decode()) + assert payload["data"] == {"ok": True} + assert payload["error"] == "" + + +class TestPluginAppBackwardsInvocation: + def test_fetch_app_info_workflow_path(self, mocker): + workflow = MagicMock() + workflow.features_dict = {"feature": "v"} + workflow.user_input_form.return_value = [{"name": "foo"}] + app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mapper = mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result == {"data": {"mapped": True}} + mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}]) + + def test_fetch_app_info_model_config_path(self, mocker): + model_config = MagicMock() + model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"} + app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result["data"] == {"mapped": True} + + @pytest.mark.parametrize( + ("mode", "route_method"), + [ + (AppMode.CHAT, "invoke_chat_app"), + (AppMode.ADVANCED_CHAT, "invoke_chat_app"), + (AppMode.AGENT_CHAT, "invoke_chat_app"), + (AppMode.WORKFLOW, "invoke_workflow_app"), + (AppMode.COMPLETION, "invoke_completion_app"), + ], ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, + def test_invoke_app_routes_by_mode(self, mocker, mode, route_method): + app = MagicMock(mode=mode) + user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user) + route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="hello", + stream=False, + inputs={"x": 1}, + files=[], + ) + + assert result == {"routed": True} + assert route.call_count == 1 + + def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker): + app = MagicMock(mode=AppMode.WORKFLOW) + end_user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + get_or_create = mocker.patch( + "core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user", + return_value=end_user, + ) + route = mocker.patch.object(PluginAppBackwardsInvocation, "invoke_workflow_app", return_value={"ok": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="", + tenant_id="tenant", + conversation_id="", + query=None, + stream=True, + inputs={}, + files=[], + ) + + assert result == {"ok": True} + get_or_create.assert_called_once_with(app) + assert route.call_args.args[1] is end_user + + def test_invoke_app_missing_query_for_chat_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT)) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="missing query"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_app_unexpected_mode_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode="other")) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="q", + stream=False, + inputs={}, + files=[], + ) + + @pytest.mark.parametrize( + ("mode", "generator_path"), + [ + (AppMode.AGENT_CHAT, "core.plugin.backwards_invocation.app.AgentChatAppGenerator.generate"), + (AppMode.CHAT, "core.plugin.backwards_invocation.app.ChatAppGenerator.generate"), + ], ) + def test_invoke_chat_app_agent_and_chat(self, mocker, mode, generator_path): + app = MagicMock(mode=mode, workflow=None) + spy = mocker.patch(generator_path, return_value={"result": "ok"}) - result = PluginAppBackwardsInvocation.invoke_chat_app( - app=app, - user=MagicMock(), - conversation_id="conv-1", - query="hello", - stream=False, - inputs={"k": "v"}, - files=[], - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + assert result == {"result": "ok"} + assert spy.call_count == 1 + def test_invoke_chat_app_advanced_chat_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" -def test_invoke_workflow_app_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" + app = MagicMock() + app.mode = AppMode.ADVANCED_CHAT + app.workflow = workflow - app = MagicMock() - app.mode = AppMode.WORKFLOW - app.workflow = workflow + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", + return_value={"result": "ok"}, + ) - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), - ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - result = PluginAppBackwardsInvocation.invoke_workflow_app( - app=app, - user=MagicMock(), - stream=False, - inputs={"k": "v"}, - files=[], - ) + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + def test_invoke_chat_app_advanced_chat_without_workflow_raises(self): + app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_chat_app_unexpected_mode_raises(self): + app = MagicMock(mode="invalid") + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_workflow_app_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.WORKFLOW + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" + + def test_invoke_workflow_app_without_workflow_raises(self): + app = MagicMock(mode=AppMode.WORKFLOW, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_completion_app(self, mocker): + spy = mocker.patch( + "core.plugin.backwards_invocation.app.CompletionAppGenerator.generate", return_value={"ok": 1} + ) + app = MagicMock(mode=AppMode.COMPLETION) + + result = PluginAppBackwardsInvocation.invoke_completion_app(app, MagicMock(), False, {"x": 1}, []) + + assert result == {"ok": 1} + assert spy.call_count == 1 + + def test_get_user_returns_end_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [MagicMock(id="end-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "end-user" + + def test_get_user_falls_back_to_account_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, MagicMock(id="account-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "account-user" + + def test_get_user_raises_when_user_not_found(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, None] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + with pytest.raises(ValueError, match="user not found"): + PluginAppBackwardsInvocation._get_user("uid") + + def test_get_app_returns_app(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + app_obj = MagicMock(id="app") + query_chain.first.return_value = app_obj + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj + + def test_get_app_raises_when_missing(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + query_chain.first.return_value = None + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + PluginAppBackwardsInvocation._get_app("app", "tenant") + + def test_get_app_raises_when_query_fails(self, mocker): + db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + PluginAppBackwardsInvocation._get_app("app", "tenant") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py new file mode 100644 index 0000000000..b0b64a601b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -0,0 +1,347 @@ +import binascii +import datetime +from enum import StrEnum + +import pytest +from flask import Response +from pydantic import ValidationError + +from core.plugin.entities.endpoint import EndpointEntityWithInstance +from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import ( + RequestInvokeLLM, + RequestInvokeSpeech2Text, + TriggerDispatchResponse, + TriggerInvokeEventResponse, +) +from core.plugin.utils.http_parser import serialize_response +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) + + +class TestEndpointEntity: + def test_endpoint_entity_with_instance_renders_url(self, mocker): + mocker.patch("core.plugin.entities.endpoint.dify_config.ENDPOINT_URL_TEMPLATE", "https://dify.test/{hook_id}") + now = datetime.datetime.now(datetime.UTC) + + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + } + ) + + assert entity.url == "https://dify.test/hook-123" + + def test_endpoint_entity_with_instance_keeps_existing_url(self): + now = datetime.datetime.now(datetime.UTC) + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + "url": "https://preset.test/hook-123", + } + ) + assert entity.url == "https://preset.test/hook-123" + + +class TestMarketplaceEntities: + def test_marketplace_declaration_strips_empty_optional_fields(self): + declaration = MarketplacePluginDeclaration.model_validate( + { + "name": "plugin", + "org": "org", + "plugin_id": "org/plugin", + "icon": "icon.png", + "label": {"en_US": "Plugin"}, + "brief": {"en_US": "Brief"}, + "resource": {"memory": 256}, + "endpoint": {}, + "model": {}, + "tool": {}, + "latest_version": "1.0.0", + "latest_package_identifier": "org/plugin@1.0.0", + "status": "active", + "deprecated_reason": "", + "alternative_plugin_id": "", + } + ) + + assert declaration.endpoint is None + assert declaration.model is None + assert declaration.tool is None + + def test_marketplace_snapshot_computed_plugin_id(self): + snapshot = MarketplacePluginSnapshot( + org="langgenius", + name="search", + latest_version="1.0.0", + latest_package_identifier="langgenius/search@1.0.0", + latest_package_url="https://example.com/pkg", + ) + assert snapshot.plugin_id == "langgenius/search" + + +class TestPluginParameterEntities: + def _label(self) -> I18nObject: + return I18nObject(en_US="label") + + def test_parameter_option_value_casts_to_string(self): + option = PluginParameterOption(value=123, label=self._label()) + assert option.value == "123" + + def test_plugin_parameter_options_non_list_defaults_to_empty(self): + parameter = PluginParameter(name="p", label=self._label(), options="invalid") # type: ignore[arg-type] + assert parameter.options == [] + + @pytest.mark.parametrize( + ("parameter_type", "expected"), + [ + (PluginParameterType.SECRET_INPUT, "string"), + (PluginParameterType.SELECT, "string"), + (PluginParameterType.CHECKBOX, "string"), + (PluginParameterType.NUMBER, PluginParameterType.NUMBER.value), + ], + ) + def test_as_normal_type(self, parameter_type, expected): + assert as_normal_type(parameter_type) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [(None, ""), (1, "1"), ("abc", "abc")], + ) + def test_cast_parameter_value_string_like(self, value, expected): + assert cast_parameter_value(PluginParameterType.STRING, value) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + ("true", True), + ("yes", True), + ("1", True), + ("false", False), + ("0", False), + ("random", True), + (1, True), + (0, False), + ], + ) + def test_cast_parameter_value_boolean(self, value, expected): + assert cast_parameter_value(PluginParameterType.BOOLEAN, value) is expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (1, 1), + (1.5, 1.5), + ("2", 2), + ("2.5", 2.5), + ], + ) + def test_cast_parameter_value_number(self, value, expected): + assert cast_parameter_value(PluginParameterType.NUMBER, value) == expected + + def test_cast_parameter_value_file_and_files(self): + assert cast_parameter_value(PluginParameterType.FILES, "f1") == ["f1"] + assert cast_parameter_value(PluginParameterType.SYSTEM_FILES, ["f1", "f2"]) == ["f1", "f2"] + assert cast_parameter_value(PluginParameterType.FILE, ["one"]) == "one" + assert cast_parameter_value(PluginParameterType.FILE, "one") == "one" + with pytest.raises(ValueError, match="only accepts one file"): + cast_parameter_value(PluginParameterType.FILE, ["a", "b"]) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.MODEL_SELECTOR, {"m": "gpt"}, {"m": "gpt"}), + (PluginParameterType.APP_SELECTOR, {"app": "a"}, {"app": "a"}), + (PluginParameterType.TOOLS_SELECTOR, [], []), + (PluginParameterType.ANY, {"k": "v"}, {"k": "v"}), + ], + ) + def test_cast_parameter_value_selectors_valid(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "message"), + [ + (PluginParameterType.MODEL_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.APP_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.TOOLS_SELECTOR, "bad", "tools selector must be a list"), + (PluginParameterType.ANY, object(), "var selector must be"), + ], + ) + def test_cast_parameter_value_selectors_invalid(self, parameter_type, value, message): + with pytest.raises(ValueError, match=message): + cast_parameter_value(parameter_type, value) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, [1, 2], [1, 2]), + (PluginParameterType.ARRAY, "[1, 2]", [1, 2]), + (PluginParameterType.OBJECT, {"k": "v"}, {"k": "v"}), + (PluginParameterType.OBJECT, '{"a":1}', {"a": 1}), + ], + ) + def test_cast_parameter_value_array_and_object_valid(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, "bad-json", ["bad-json"]), + (PluginParameterType.OBJECT, "bad-json", {}), + ], + ) + def test_cast_parameter_value_array_and_object_invalid_json_fallback(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + def test_cast_parameter_value_default_branch_and_wrapped_exception(self): + class _Unknown(StrEnum): + CUSTOM = "custom" + + assert cast_parameter_value(_Unknown.CUSTOM, 12) == "12" + + class _BadString: + def __str__(self): + raise RuntimeError("boom") + + with pytest.raises( + ValueError, + match=r"The tool parameter value <.*_BadString object at .* is not in correct type of string\.", + ): + cast_parameter_value(PluginParameterType.STRING, _BadString()) + + def test_init_frontend_parameter(self): + rule = PluginParameter( + name="choice", + label=self._label(), + required=True, + default="a", + options=[PluginParameterOption(value="a", label=self._label())], + ) + + assert init_frontend_parameter(rule, PluginParameterType.SELECT, None) == "a" + assert init_frontend_parameter(rule, PluginParameterType.NUMBER, 0) == 0 + with pytest.raises(ValueError, match="not in options"): + init_frontend_parameter(rule, PluginParameterType.SELECT, "b") + + required_rule = PluginParameter(name="required", label=self._label(), required=True, default=None) + with pytest.raises(ValueError, match="not found in tool config"): + init_frontend_parameter(required_rule, PluginParameterType.STRING, None) + + +class TestPluginDaemonEntities: + def test_credential_type_helpers(self): + assert CredentialType.API_KEY.get_name() == "API KEY" + assert CredentialType.OAUTH2.get_name() == "AUTH" + assert CredentialType.UNAUTHORIZED.get_name() == "UNAUTHORIZED" + + class _FakeCredential: + value = "custom-type" + + assert CredentialType.get_name(_FakeCredential()) == "CUSTOM TYPE" + assert CredentialType.API_KEY.is_editable() is True + assert CredentialType.OAUTH2.is_editable() is False + assert CredentialType.API_KEY.is_validate_allowed() is True + assert CredentialType.UNAUTHORIZED.is_validate_allowed() is False + assert set(CredentialType.values()) == {"api-key", "oauth2", "unauthorized"} + + @pytest.mark.parametrize( + ("raw", "expected"), + [ + ("api-key", CredentialType.API_KEY), + ("api_key", CredentialType.API_KEY), + ("oauth2", CredentialType.OAUTH2), + ("oauth", CredentialType.OAUTH2), + ("unauthorized", CredentialType.UNAUTHORIZED), + ], + ) + def test_credential_type_of(self, raw, expected): + assert CredentialType.of(raw) == expected + + def test_credential_type_of_invalid(self): + with pytest.raises(ValueError, match="Invalid credential type"): + CredentialType.of("invalid") + + +class TestPluginRequestEntities: + def test_request_invoke_llm_converts_prompt_messages(self): + payload = RequestInvokeLLM( + provider="openai", + model="gpt-4", + mode="chat", + prompt_messages=[ + {"role": "user", "content": "u"}, + {"role": "assistant", "content": "a"}, + {"role": "system", "content": "s"}, + {"role": "tool", "content": "t", "tool_call_id": "call-1"}, + ], + ) + + assert isinstance(payload.prompt_messages[0], UserPromptMessage) + assert isinstance(payload.prompt_messages[1], AssistantPromptMessage) + assert isinstance(payload.prompt_messages[2], SystemPromptMessage) + assert isinstance(payload.prompt_messages[3], ToolPromptMessage) + + def test_request_invoke_llm_prompt_messages_must_be_list(self): + with pytest.raises(ValidationError): + RequestInvokeLLM(provider="openai", model="gpt-4", mode="chat", prompt_messages="invalid") # type: ignore[arg-type] + + def test_request_invoke_speech2text_hex_conversion_and_error(self): + payload = RequestInvokeSpeech2Text(provider="openai", model="m", file=binascii.hexlify(b"abc").decode()) + assert payload.file == b"abc" + with pytest.raises(ValidationError): + RequestInvokeSpeech2Text(provider="openai", model="m", file=b"abc") # type: ignore[arg-type] + + def test_trigger_invoke_event_response_variables_conversion(self): + converted = TriggerInvokeEventResponse(variables='{"a": 1}', cancelled=False) + assert converted.variables == {"a": 1} + passthrough = TriggerInvokeEventResponse(variables={"b": 2}, cancelled=True) + assert passthrough.variables == {"b": 2} + + def test_trigger_dispatch_response_convert_response(self): + response = Response("ok", status=202, headers={"X-Req": "1"}) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.response.status_code == 202 + assert parsed.response.get_data() == b"ok" + with pytest.raises(ValidationError): + TriggerDispatchResponse(user_id="u", events=["e"], response="not-hex") + + def test_trigger_dispatch_response_payload_default(self): + response = Response("ok", status=200) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.payload == {} diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index e0eace0f2d..c7e94aa4cf 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -4,7 +4,10 @@ import pytest from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.plugin.utils.converter import convert_parameters_to_plugin_format +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File class TestChunkMerger: @@ -458,3 +461,89 @@ class TestChunkMerger: assert len(result) == 1 assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage) assert result[0].message.blob == b"FirstSecondThird" + + +class TestConverter: + def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): + file_param = File( + tenant_id="tenant-1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/file.png", + storage_key="", + ) + selector = ToolSelector( + provider_id="org/plugin/provider", + credential_id=None, + tool_name="search", + tool_description="search tool", + tool_configuration={"k": "v"}, + tool_parameters={ + "query": ToolSelector.Parameter( + name="query", + type=ToolParameter.ToolParameterType.STRING, + required=True, + description="query", + default="python", + options=[], + ) + }, + ) + params = {"file": file_param, "selector": selector, "plain": 123} + + converted = convert_parameters_to_plugin_format(params) + + assert converted["file"]["url"] == "https://example.com/file.png" + assert converted["selector"]["provider_id"] == "org/plugin/provider" + assert converted["plain"] == 123 + + def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): + file_one = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/a.txt", + storage_key="", + ) + file_two = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/b.txt", + storage_key="", + ) + selector_one = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-1", + tool_name="t1", + tool_description="tool 1", + tool_configuration={}, + tool_parameters={}, + ) + selector_two = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-2", + tool_name="t2", + tool_description="tool 2", + tool_configuration={}, + tool_parameters={}, + ) + + params = { + "files": [file_one, file_two], + "selectors": [selector_one, selector_two], + "empty_list": [], + "mixed_list": [file_one, "raw"], + "none_value": None, + } + + converted = convert_parameters_to_plugin_format(params) + + assert [item["url"] for item in converted["files"]] == [ + "https://example.com/a.txt", + "https://example.com/b.txt", + ] + assert [item["tool_name"] for item in converted["selectors"]] == ["t1", "t2"] + assert converted["empty_list"] == [] + assert converted["mixed_list"] == [file_one, "raw"] + assert converted["none_value"] is None diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py index 1c2e0c96f8..71144695bc 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -381,6 +381,54 @@ class TestEdgeCases: assert response.status_code == 200 assert response.get_data() == binary_body + def test_deserialize_request_with_lf_only_newlines(self): + raw_data = b"POST /lf-only?x=1 HTTP/1.1\nHost: localhost\nX-Test: yes\n\npayload" + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/lf-only" + assert request.args.get("x") == "1" + assert request.headers.get("X-Test") == "yes" + assert request.get_data() == b"payload" + + def test_deserialize_request_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"GET /no-separator HTTP/1.1\nHost: localhost\nInvalidHeader\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/no-separator" + assert request.headers.get("Host") == "localhost" + assert request.headers.get("InvalidHeader") is None + + def test_deserialize_request_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP request"): + deserialize_request(b"") + + def test_deserialize_response_with_lf_only_newlines(self): + raw_data = b"HTTP/1.1 202 Accepted\nX-Test: yes\n\nbody" + + response = deserialize_response(raw_data) + + assert response.status_code == 202 + assert response.headers.get("X-Test") == "yes" + assert response.get_data() == b"body" + + def test_deserialize_response_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"HTTP/1.1 204 No Content\nX-Test: yes\nInvalidHeader\n" + + response = deserialize_response(raw_data) + + assert response.status_code == 204 + assert response.headers.get("X-Test") == "yes" + assert response.headers.get("InvalidHeader") is None + assert response.get_data() == b"" + + def test_deserialize_response_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP response"): + deserialize_response(b"") + class TestFileUploads: def test_serialize_request_with_text_file_upload(self): diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3e184cbf21..3d08525aba 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -1,3 +1,4 @@ +from typing import cast from unittest.mock import MagicMock, patch import pytest @@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, UserPromptMessage, ) from models.model import Conversation @@ -188,3 +191,328 @@ def get_chat_model_args(): context = "I am superman." return model_config_mock, memory_config, prompt_messages, inputs, context + + +def test_get_prompt_dispatches_completion_and_chat_and_invalid(): + transform = AdvancedPromptTransform() + model_config = MagicMock(spec=ModelConfigEntity) + completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic") + chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")] + + transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")]) + transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")]) + + completion_result = transform.get_prompt( + prompt_template=completion_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert completion_result[0].content == "c" + + chat_result = transform.get_prompt( + prompt_template=chat_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert chat_result[0].content == "h" + + invalid_result = transform.get_prompt( + prompt_template=cast(list, ["not-chat-model-message"]), + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert invalid_result == [] + + +def test_completion_prompt_jinja2_with_files(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") + + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with ( + patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"), + patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content, + ): + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_completion_model_prompt_messages( + prompt_template=completion_template, + inputs={"name": "John"}, + query="", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0].content, list) + assert messages[0].content[0].data == "https://example.com/image.jpg" + assert isinstance(messages[0].content[1], TextPromptMessageContent) + assert messages[0].content[1].data == "Hi John" + + +def test_completion_prompt_basic_sets_query_variable(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + template = CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic") + + messages = transform._get_completion_model_prompt_messages( + prompt_template=template, + inputs={}, + query="what?", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert messages[0].content == "Q=what?" + + +def test_chat_prompt_with_variable_template_and_context(): + transform = AdvancedPromptTransform(with_variable_tmpl=True) + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"#node.name#": "john"}, + query=None, + files=[], + context="context-text", + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0], SystemPromptMessage) + assert messages[0].content == "sys=john ctx=context-text" + + +def test_chat_prompt_jinja2_branch_and_invalid_edition(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")] + + with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"): + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"name": "John"}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert messages[0].content == "Hello John" + + bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")] + with pytest.raises(ValueError, match="Invalid edition type"): + transform._get_chat_model_prompt_messages( + prompt_template=bad_prompt_template, + inputs={}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + +def test_chat_prompt_query_template_and_query_only_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + query_prompt_template="query={{#sys.query#}} ctx={{#context#}}", + ) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="what", + files=[], + context="ctx", + memory_config=memory_config, + memory=None, + model_config=model_config_mock, + ) + assert messages[-1].content == "query={{#sys.query#}} ctx=ctx" + + +def test_chat_prompt_memory_with_files_and_query(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) + memory = MagicMock(spec=TokenBufferMemory) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + transform._append_chat_histories = MagicMock( + side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages + ) + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="q", + files=[file], + context=None, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "q" + + +def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_with_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "u" + + prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_without_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1], UserPromptMessage) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "" + + +def test_chat_prompt_files_with_query_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=[], + inputs={}, + query="query-text", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "query-text" + + +def test_set_context_query_histories_variable_helpers(): + transform = AdvancedPromptTransform() + parser_context = PromptTemplateParser(template="{{#context#}}") + parser_query = PromptTemplateParser(template="{{#query#}}") + parser_hist = PromptTemplateParser(template="{{#histories#}}") + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ) + + assert transform._set_context_variable(None, parser_context, {})["#context#"] == "" + assert transform._set_query_variable("", parser_query, {})["#query#"] == "" + assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x" + assert ( + transform._set_histories_variable( + memory=None, # type: ignore[arg-type] + memory_config=memory_config, + raw_prompt="{{#histories#}}", + role_prefix=memory_config.role_prefix, # type: ignore[arg-type] + parser=parser_hist, + prompt_inputs={}, + model_config=model_config_mock, + )["#histories#"] + == "" + ) diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py index e3e500e310..1b114b369a 100644 --- a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -2,12 +2,14 @@ from uuid import uuid4 from constants import UUID_NIL from core.prompt.utils.extract_thread_messages import extract_thread_messages +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length class MockMessage: - def __init__(self, id, parent_message_id): + def __init__(self, id, parent_message_id, answer="answer"): self.id = id self.parent_message_id = parent_message_id + self.answer = answer def __getitem__(self, item): return getattr(self, item) @@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages(): result = extract_thread_messages(messages) assert len(result) == 4 assert [msg["id"] for msg in result] == [id5, id4, id2, id1] + + +def test_extract_thread_messages_breaks_when_parent_is_none(): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [MockMessage(id2, None), MockMessage(id1, UUID_NIL)] + + result = extract_thread_messages(messages) + + assert len(result) == 1 + assert result[0].id == id2 + + +def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer=""), # newest generated message should be excluded + MockMessage(id1, UUID_NIL, answer="ok"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-1") + + assert length == 1 + mock_scalars.assert_called_once() + + +def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer="latest-answer"), + MockMessage(id1, UUID_NIL, answer="older-answer"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-2") + + assert length == 2 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 4136816562..9fc300348a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,11 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) @@ -25,3 +30,82 @@ def test_dump_prompt_message(): ) data = prompt.model_dump() assert data["content"][0].get("url") == example_url + + +def test_prompt_messages_to_prompt_for_saving_chat_mode(): + chat_messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="hello "), + ImagePromptMessageContent( + url="https://example.com/image1.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + AudioPromptMessageContent( + url="https://example.com/audio1.mp3", + format="mp3", + mime_type="audio/mpeg", + ), + TextPromptMessageContent(data="world"), + ] + ), + AssistantPromptMessage( + content="assistant-text", + tool_calls=[ + { + "id": "tool-1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"python"}'}, + } + ], + ), + ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"), + UserPromptMessage.model_construct(role="unknown", content="skip"), # type: ignore[arg-type] + ] + + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages) + + assert len(prompts) == 3 + assert prompts[0]["role"] == "user" + assert prompts[0]["text"] == "hello world" + assert prompts[0]["files"][0]["type"] == "image" + assert prompts[0]["files"][1]["type"] == "audio" + + assert prompts[1]["role"] == "assistant" + assert prompts[1]["text"] == "assistant-text" + assert prompts[1]["tool_calls"][0]["function"]["name"] == "search" + assert prompts[2]["role"] == "tool" + + +def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files(): + completion_message_with_files = UserPromptMessage( + content=[ + TextPromptMessageContent(data="first "), + TextPromptMessageContent(data="second"), + ImagePromptMessageContent( + url="https://example.com/image2.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.LOW, + ), + ] + ) + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_with_files] + ) + assert prompts == [ + { + "role": "user", + "text": "first second", + "files": prompts[0]["files"], + } + ] + assert prompts[0]["files"][0]["type"] == "image" + + completion_message_text_only = UserPromptMessage(content="plain text") + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_text_only] + ) + assert prompts == [{"role": "user", "text": "plain text"}] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 7976120547..d379e3067a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,4 +1,10 @@ -# from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -9,44 +15,217 @@ # from core.prompt.prompt_transform import PromptTransform -# def test__calculate_rest_token(): -# model_schema_mock = MagicMock(spec=AIModelEntity) -# parameter_rule_mock = MagicMock(spec=ParameterRule) -# parameter_rule_mock.name = "max_tokens" -# model_schema_mock.parameter_rules = [parameter_rule_mock] -# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} +class TestPromptTransform: + def test_resolve_model_runtime_requires_model_config_or_instance(self): + transform = PromptTransform() -# large_language_model_mock = MagicMock(spec=LargeLanguageModel) -# large_language_model_mock.get_num_tokens.return_value = 6 + with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."): + transform._resolve_model_runtime() -# provider_mock = MagicMock(spec=ProviderEntity) -# provider_mock.provider = "openai" + def test_resolve_model_runtime_builds_model_instance_from_model_config(self): + transform = PromptTransform() + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = fake_model_schema + fake_model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials=None, + parameters=None, + stop=None, + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="config-model", + credentials={"api_key": "secret"}, + parameters={"temperature": 0.1}, + stop=["END"], + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + ) -# provider_configuration_mock = MagicMock(spec=ProviderConfiguration) -# provider_configuration_mock.provider = provider_mock -# provider_configuration_mock.model_settings = None + with patch( + "core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance + ) as model_instance_cls: + model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config) -# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) -# provider_model_bundle_mock.model_type_instance = large_language_model_mock -# provider_model_bundle_mock.configuration = provider_configuration_mock + model_instance_cls.assert_called_once_with( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model, + ) + fake_model_type_instance.get_model_schema.assert_called_once_with( + model="resolved-model", + credentials={"api_key": "secret"}, + ) + assert model_instance is fake_model_instance + assert model_instance.credentials == {"api_key": "secret"} + assert model_instance.parameters == {"temperature": 0.1} + assert model_instance.stop == ["END"] + assert model_schema is fake_model_schema -# model_config_mock = MagicMock(spec=ModelConfigEntity) -# model_config_mock.model = "gpt-4" -# model_config_mock.credentials = {} -# model_config_mock.parameters = {"max_tokens": 50} -# model_config_mock.model_schema = model_schema_mock -# model_config_mock.provider_model_bundle = provider_model_bundle_mock + def test_resolve_model_runtime_uses_model_config_schema_fallback(self): + transform = PromptTransform() + fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + model_config = SimpleNamespace(model_schema=fallback_schema) -# prompt_transform = PromptTransform() + resolved_model_instance, resolved_schema = transform._resolve_model_runtime( + model_config=model_config, + model_instance=model_instance, + ) -# prompt_messages = [UserPromptMessage(content="Hello, how are you?")] -# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + assert resolved_model_instance is model_instance + assert resolved_schema is fallback_schema -# # Validate based on the mock configuration and expected logic -# expected_rest_tokens = ( -# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] -# - model_config_mock.parameters["max_tokens"] -# - large_language_model_mock.get_num_tokens.return_value -# ) -# assert rest_tokens == expected_rest_tokens -# assert rest_tokens == 6 + def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self): + transform = PromptTransform() + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + + with pytest.raises(ValueError, match="Model schema not found for the provided model instance."): + transform._resolve_model_runtime(model_instance=model_instance) + + def test_calculate_rest_token_defaults_when_context_size_missing(self): + transform = PromptTransform() + fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0) + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + provider_model_bundle=object(), + model="test-model", + parameters={}, + ) + + rest = transform._calculate_rest_token([], model_config=model_config) + + assert rest == 2000 + + def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="max_tokens", use_template=None) + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 50}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 0 + + def test_calculate_rest_token_supports_use_template_parameter(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens") + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 30}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 150 + + def test_get_history_messages_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_text.return_value = "history" + + memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3)) + result = transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_with_window, + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert result == "history" + memory.get_history_prompt_text.assert_called_with( + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + message_limit=3, + ) + + memory.reset_mock() + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2)) + transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_no_window, + max_token_limit=50, + ) + memory.get_history_prompt_text.assert_called_with(max_token_limit=50) + + def test_get_history_messages_list_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_messages.return_value = ["m1", "m2"] + + memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120) + assert result == ["m1", "m2"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2) + + memory.reset_mock() + memory.get_history_prompt_messages.return_value = ["only"] + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10) + assert result == ["only"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None) + + def test_append_chat_histories_extends_prompt_messages(self, monkeypatch): + transform = PromptTransform() + memory = MagicMock() + memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None)) + + monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99) + monkeypatch.setattr( + transform, + "_get_history_messages_list_from_memory", + lambda memory, memory_config, max_token_limit: ["h1", "h2"], + ) + + result = transform._append_chat_histories( + memory=memory, + memory_config=memory_config, + prompt_messages=["p1"], + model_config=SimpleNamespace(), + ) + + assert result == ["p1", "h1", "h2"] diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 2ef66e8a96..e6d28224d7 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,9 +1,29 @@ -from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation @@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages(): assert len(prompt_messages) == 1 assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt + + +def test_get_prompt_dispatches_chat_and_completion(): + transform = SimplePromptTransform() + model_config_chat = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_chat.mode = "chat" + model_config_completion = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_completion.mode = "completion" + prompt_entity = SimpleNamespace(simple_prompt_template="hello") + + transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None)) + transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"])) + + chat_messages, chat_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_chat, + ) + assert chat_messages == ["chat-msg"] + assert chat_stops is None + + completion_messages, completion_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_completion, + ) + assert completion_messages == ["completion-msg"] + assert completion_stops == ["stop"] + + +def test_get_prompt_str_and_rules_type_validation_errors(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config.provider = "openai" + model_config.model = "gpt-4" + valid_prompt_template = SimplePromptTransform().get_prompt_template( + AppMode.CHAT, "openai", "gpt-4", "", False, False + )["prompt_template"] + + bad_custom_keys = { + "prompt_template": valid_prompt_template, + "custom_variable_keys": "not-list", + "special_variable_keys": [], + "prompt_rules": {}, + } + transform.get_prompt_template = MagicMock(return_value=bad_custom_keys) + with pytest.raises(TypeError, match="custom_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_special_keys = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": "not-list", + } + transform.get_prompt_template = MagicMock(return_value=bad_special_keys) + with pytest.raises(TypeError, match="special_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_prompt_template = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": 123, + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_template) + with pytest.raises(TypeError, match="PromptTemplateParser"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_prompt_rules = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": valid_prompt_template, + "prompt_rules": "not-dict", + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_rules) + with pytest.raises(TypeError, match="prompt_rules"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + +def test_chat_model_prompt_messages_uses_prompt_when_query_empty(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {})) + transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text")) + + prompt_messages, _ = transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert prompt_messages[0].content == "prompt-text" + transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None) + + +def test_completion_model_prompt_messages_empty_stops_becomes_none(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []})) + + prompt_messages, stops = transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert len(prompt_messages) == 1 + assert stops is None + + +def test_get_last_user_message_with_files_and_context_files(): + transform = SimplePromptTransform() + file = SimpleNamespace() + context_file = SimpleNamespace() + + with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.side_effect = [ + ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"), + ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"), + ] + message = transform._get_last_user_message( + prompt="hello", + files=[file], + context_files=[context_file], + image_detail_config=None, + ) + + assert isinstance(message.content, list) + assert message.content[0].data == "https://example.com/a.jpg" + assert message.content[1].data == "https://example.com/b.jpg" + assert isinstance(message.content[2], TextPromptMessageContent) + assert message.content[2].data == "hello" + + +def test_prompt_file_name_branches(): + transform = SimplePromptTransform() + + assert transform._prompt_file_name(AppMode.CHAT, "openai", "gpt-4") == "common_chat" + assert transform._prompt_file_name(AppMode.COMPLETION, "openai", "gpt-4") == "common_completion" + assert transform._prompt_file_name(AppMode.COMPLETION, "baichuan", "Baichuan2") == "baichuan_completion" + assert transform._prompt_file_name(AppMode.CHAT, "huggingface_hub", "baichuan-13b") == "baichuan_chat" + + +def test_advanced_prompt_templates_constants_are_importable(): + assert isinstance(CONTEXT, str) + assert isinstance(BAICHUAN_CONTEXT, str) + assert "completion_prompt_config" in CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in COMPLETION_APP_COMPLETION_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index d47d4d6130..91259c9a45 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,11 +1,14 @@ import dataclasses +import orjson +import pytest from pydantic import BaseModel from core.helper import encrypter from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime import VariablePool from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, @@ -23,6 +26,11 @@ from dify_graph.variables.segments import ( get_segment_discriminator, ) from dify_graph.variables.types import SegmentType +from dify_graph.variables.utils import ( + dumps_with_segments, + segment_orjson_default, + to_selector, +) from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, @@ -379,3 +387,125 @@ class TestSegmentDumpAndLoad: assert get_segment_discriminator("not_a_dict") is None assert get_segment_discriminator(42) is None assert get_segment_discriminator(object) is None + + +class TestSegmentAdditionalProperties: + def test_base_segment_text_log_markdown_size_and_to_object(self): + """Ensure StringSegment exposes text, log, markdown, size and to_object.""" + segment = StringSegment(value="hello") + + assert segment.text == "hello" + assert segment.log == "hello" + assert segment.markdown == "hello" + assert segment.size > 0 + assert segment.to_object() == "hello" + + def test_none_segment_empty_outputs(self): + """Ensure NoneSegment renders empty text, log and markdown.""" + segment = NoneSegment() + + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + + def test_object_segment_json_outputs(self): + """Ensure ObjectSegment renders JSON output for text, log and markdown.""" + segment = ObjectSegment(value={"key": "值", "n": 1}) + + assert segment.text == '{"key": "值", "n": 1}' + assert segment.log == '{\n "key": "值",\n "n": 1\n}' + assert segment.markdown == '{\n "key": "值",\n "n": 1\n}' + + def test_array_segment_text_and_markdown(self): + """Ensure ArrayAnySegment handles empty/non-empty text and markdown rendering.""" + empty_segment = ArrayAnySegment(value=[]) + non_empty_segment = ArrayAnySegment(value=[1, "two"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == "[1, 'two']" + assert non_empty_segment.markdown == "- 1\n- two" + + def test_file_segment_properties(self): + """Ensure FileSegment markdown, text and log fields match expected behavior.""" + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="doc.txt") + segment = FileSegment(value=file) + + assert segment.markdown == "[doc.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + def test_array_string_segment_text_branches(self): + """Ensure ArrayStringSegment text handling for empty and non-empty values.""" + empty_segment = ArrayStringSegment(value=[]) + non_empty_segment = ArrayStringSegment(value=["hello", "世界"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == '["hello", "世界"]' + + def test_array_file_segment_markdown_and_empty_text_log(self): + """Ensure ArrayFileSegment markdown renders links and text/log stay empty.""" + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + segment = ArrayFileSegment(value=[file1, file2]) + + assert segment.markdown == "[a.txt](https://example.com/file.txt)\n[b.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + +class TestSegmentGroupAdditional: + def test_segment_group_markdown_and_to_object(self): + group = SegmentGroup(value=[StringSegment(value="A"), NoneSegment(), StringSegment(value="B")]) + + assert group.markdown == "AB" + assert group.to_object() == ["A", None, "B"] + + +class TestSegmentUtils: + def test_to_selector_without_paths(self): + assert to_selector("node-1", "output") == ["node-1", "output"] + + def test_to_selector_with_paths(self): + assert to_selector("node-1", "output", ("a", "b")) == ["node-1", "output", "a", "b"] + + def test_array_file_segment_serialization(self): + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + + result = segment_orjson_default(ArrayFileSegment(value=[file1, file2])) + + assert len(result) == 2 + assert result[0]["filename"] == "a.txt" + assert result[1]["filename"] == "b.txt" + + def test_file_segment_serialization(self): + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="single.txt") + + result = segment_orjson_default(FileSegment(value=file)) + + assert result["filename"] == "single.txt" + assert result["remote_url"] == "https://example.com/file.txt" + + def test_segment_group_and_segment_serialization(self): + group = SegmentGroup(value=[StringSegment(value="a"), StringSegment(value="b")]) + + assert segment_orjson_default(group) == ["a", "b"] + assert segment_orjson_default(StringSegment(value="value")) == "value" + + def test_segment_orjson_default_unsupported_type(self): + with pytest.raises(TypeError, match="not JSON serializable"): + segment_orjson_default(object()) + + def test_dumps_with_segments(self): + data = { + "segment": StringSegment(value="hello"), + "group": SegmentGroup(value=[StringSegment(value="x"), StringSegment(value="y")]), + 1: "numeric-key", + } + + dumped = dumps_with_segments(data) + loaded = orjson.loads(dumped) + + assert loaded["segment"] == "hello" + assert loaded["group"] == ["x", "y"] + assert loaded["1"] == "numeric-key" diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index c3371d92e3..bb234d9bbd 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,5 +1,7 @@ import pytest +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import ArrayValidation, SegmentType @@ -69,22 +71,36 @@ class TestSegmentTypeIsValidArrayValidation: """ def test_array_validation_all_success(self): + # Arrange value = ["hello", "world", "foo"] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert is_valid def test_array_validation_all_fail(self): + # Arrange value = ["hello", 123, "world"] - # Should return False, since 123 is not a string - assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert not is_valid def test_array_validation_first(self): + # Arrange value = ["hello", 123, None] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Assert + assert is_valid def test_array_validation_none(self): + # Arrange value = [1, 2, 3] - # validation is None, skip - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Assert + assert is_valid class TestSegmentTypeGetZeroValue: @@ -163,3 +179,62 @@ class TestSegmentTypeGetZeroValue: for seg_type in unsupported_types: with pytest.raises(ValueError, match="unsupported variable type"): SegmentType.get_zero_value(seg_type) + + +class TestSegmentTypeInferSegmentType: + @pytest.mark.parametrize( + ("value", "expected"), + [ + ([], SegmentType.ARRAY_NUMBER), + ([1, 2, 3], SegmentType.ARRAY_NUMBER), + ([1, 2.5], SegmentType.ARRAY_NUMBER), + (["a", "b"], SegmentType.ARRAY_STRING), + ([{"k": "v"}], SegmentType.ARRAY_OBJECT), + ([None], SegmentType.ARRAY_ANY), + ([True, False], SegmentType.ARRAY_BOOLEAN), + ([[1], [2]], SegmentType.ARRAY_ANY), + ([1, "a"], SegmentType.ARRAY_ANY), + (None, SegmentType.NONE), + (True, SegmentType.BOOLEAN), + (1, SegmentType.INTEGER), + (1.2, SegmentType.FLOAT), + ("abc", SegmentType.STRING), + ({"k": "v"}, SegmentType.OBJECT), + ], + ) + def test_infer_segment_type_supported_values(self, value, expected): + assert SegmentType.infer_segment_type(value) == expected + + +class TestSegmentTypeAdditionalMethods: + def test_cast_value_for_bool_number_and_array_number(self): + assert SegmentType.cast_value(True, SegmentType.INTEGER) == 1 + assert SegmentType.cast_value(False, SegmentType.NUMBER) == 0 + assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0] + + mixed = [True, 1] + assert SegmentType.cast_value(mixed, SegmentType.ARRAY_NUMBER) is mixed + assert SegmentType.cast_value("x", SegmentType.STRING) == "x" + + def test_exposed_type_and_element_type(self): + assert SegmentType.INTEGER.exposed_type() == SegmentType.NUMBER + assert SegmentType.FLOAT.exposed_type() == SegmentType.NUMBER + assert SegmentType.STRING.exposed_type() == SegmentType.STRING + + assert SegmentType.ARRAY_STRING.element_type() == SegmentType.STRING + assert SegmentType.ARRAY_ANY.element_type() is None + + with pytest.raises(ValueError, match="element_type is only supported by array type"): + SegmentType.STRING.element_type() + + def test_group_validation_for_segment_group_and_list(self): + valid_group = SegmentGroup(value=[StringSegment(value="a")]) + assert SegmentType.GROUP.is_valid(valid_group) is True + assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True + assert SegmentType.GROUP.is_valid(["not-segment"]) is False + + def test_unreachable_assertion_branch(self, monkeypatch): + monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False) + + with pytest.raises(AssertionError, match="unreachable"): + SegmentType.ARRAY_STRING.is_valid(["a"])