test: unit test cases for core.variables, core.plugin, core.prompt module (#32637)

This commit is contained in:
Rajat Agarwal 2026-03-12 08:59:02 +05:30 committed by GitHub
parent 135b3a15a6
commit 07e19c0748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 3526 additions and 97 deletions

View File

@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
except ValueError:
raise
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):

View File

@@ -0,0 +1,91 @@
from types import SimpleNamespace
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
def _agent_provider(name: str = "agent") -> SimpleNamespace:
return SimpleNamespace(
plugin_id="org/plugin",
declaration=SimpleNamespace(
identity=SimpleNamespace(name=name),
strategies=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
),
)
class TestPluginAgentClient:
    """Unit tests for PluginAgentClient's daemon-backed agent-strategy API."""

    def test_fetch_agent_strategy_providers(self, mocker):
        """Listing providers rewrites names/providers to 'plugin_id/name' form."""
        client = PluginAgentClient()
        provider = _agent_provider("remote")

        def fake_request(method, path, type_, **kwargs):
            # Exercise the transformer the client hands to the daemon request:
            # it should rewrite the raw strategy provider to the identity name.
            transformer = kwargs["transformer"]
            payload = {
                "data": [
                    {
                        "declaration": {
                            "identity": {"name": "remote"},
                            "strategies": [{"identity": {"provider": "old"}}],
                        }
                    }
                ]
            }
            transformed = transformer(payload)
            assert transformed["data"][0]["declaration"]["strategies"][0]["identity"]["provider"] == "remote"
            return [provider]

        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = client.fetch_agent_strategy_providers("tenant-1")
        assert request_mock.call_count == 1
        # After fetching, names are fully qualified with the plugin id.
        assert result[0].declaration.identity.name == "org/plugin/remote"
        assert result[0].declaration.strategies[0].identity.provider == "org/plugin/remote"

    def test_fetch_agent_strategy_provider(self, mocker):
        """Fetching one provider tolerates empty data and qualifies the name."""
        client = PluginAgentClient()
        provider = _agent_provider("provider")

        def fake_request(method, path, type_, **kwargs):
            transformer = kwargs["transformer"]
            # A null payload must pass through the transformer untouched.
            assert transformer({"data": None}) == {"data": None}
            payload = {"data": {"declaration": {"strategies": [{"identity": {"provider": "old"}}]}}}
            transformed = transformer(payload)
            assert transformed["data"]["declaration"]["strategies"][0]["identity"]["provider"] == "provider"
            return provider

        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = client.fetch_agent_strategy_provider("tenant-1", "org/plugin/provider")
        assert request_mock.call_count == 1
        assert result.declaration.identity.name == "org/plugin/provider"
        assert result.declaration.strategies[0].identity.provider == "org/plugin/provider"

    def test_invoke_merges_chunks_and_passes_context(self, mocker):
        """invoke() merges streamed blob chunks and forwards the invoke context."""
        client = PluginAgentClient()
        stream_mock = mocker.patch.object(
            client, "_request_with_plugin_daemon_response_stream", return_value=iter(["raw"])
        )
        merge_mock = mocker.patch("core.plugin.impl.agent.merge_blob_chunks", return_value=["merged"])
        context = PluginInvokeContext()
        result = client.invoke(
            tenant_id="tenant-1",
            user_id="user-1",
            agent_provider="org/plugin/provider",
            agent_strategy="router",
            agent_params={"k": "v"},
            conversation_id="conv-1",
            app_id="app-1",
            message_id="msg-1",
            context=context,
        )
        assert result == ["merged"]
        assert merge_mock.call_count == 1
        payload = stream_mock.call_args.kwargs["data"]
        # Only the bare provider name (after the plugin prefix) goes on the wire.
        assert payload["data"]["agent_strategy_provider"] == "provider"
        assert payload["context"] == context.model_dump()
        assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin"

View File

@@ -0,0 +1,45 @@
from unittest.mock import MagicMock
import pytest
from core.plugin.impl.asset import PluginAssetManager
class TestPluginAssetManager:
    """Unit tests for PluginAssetManager fetch/extract endpoints."""

    def test_fetch_asset_success(self, mocker):
        """A 200 response returns the raw asset bytes from the daemon."""
        manager = PluginAssetManager()
        response = MagicMock(status_code=200, content=b"asset-bytes")
        request_mock = mocker.patch.object(manager, "_request", return_value=response)
        result = manager.fetch_asset("tenant-1", "asset-1")
        assert result == b"asset-bytes"
        request_mock.assert_called_once_with(method="GET", path="plugin/tenant-1/asset/asset-1")

    def test_fetch_asset_not_found_raises(self, mocker):
        """A non-200 status surfaces as a ValueError naming the asset."""
        manager = PluginAssetManager()
        mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b""))
        with pytest.raises(ValueError, match="can not found asset asset-1"):
            manager.fetch_asset("tenant-1", "asset-1")

    def test_extract_asset_success(self, mocker):
        """extract_asset passes the plugin identifier and file path as params."""
        manager = PluginAssetManager()
        response = MagicMock(status_code=200, content=b"file-content")
        request_mock = mocker.patch.object(manager, "_request", return_value=response)
        result = manager.extract_asset("tenant-1", "org/plugin:1", "README.md")
        assert result == b"file-content"
        request_mock.assert_called_once_with(
            method="GET",
            path="plugin/tenant-1/extract-asset/",
            params={"plugin_unique_identifier": "org/plugin:1", "file_path": "README.md"},
        )

    def test_extract_asset_not_found_raises(self, mocker):
        """A non-200 extract response raises with identifier and status code."""
        manager = PluginAssetManager()
        mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b""))
        with pytest.raises(ValueError, match="can not found asset org/plugin:1, 404"):
            manager.extract_asset("tenant-1", "org/plugin:1", "README.md")

View File

@@ -0,0 +1,137 @@
import json
import pytest
from core.plugin.endpoint.exc import EndpointSetupFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
from core.plugin.impl.base import BasePluginClient
from core.trigger.errors import (
EventIgnoreError,
TriggerInvokeError,
TriggerPluginInvokeError,
TriggerProviderCredentialValidationError,
)
class _ResponseStub:
def __init__(self, payload):
self._payload = payload
def raise_for_status(self):
return None
def json(self):
return self._payload
class _StreamContext:
def __init__(self, lines):
self._lines = lines
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def iter_lines(self):
return self._lines
class TestBasePluginClientImpl:
    """Unit tests for BasePluginClient request/stream plumbing and error mapping."""

    def test_inject_trace_headers(self, mocker):
        """traceparent is added when OTEL is on, but an existing header wins."""
        client = BasePluginClient()
        mocker.patch("core.plugin.impl.base.dify_config.ENABLE_OTEL", True)
        trace_header = "00-abc-xyz-01"
        mocker.patch("core.helper.trace_id_helper.generate_traceparent_header", return_value=trace_header)
        headers = {}
        client._inject_trace_headers(headers)
        assert headers["traceparent"] == trace_header
        # A pre-existing header (any casing) must not be overwritten.
        headers_with_existing = {"TraceParent": "exists"}
        client._inject_trace_headers(headers_with_existing)
        assert headers_with_existing["TraceParent"] == "exists"

    def test_stream_request_handles_data_lines_and_dict_payload(self, mocker):
        """Empty lines are skipped and 'data: ' SSE prefixes are stripped."""
        client = BasePluginClient()
        stream_mock = mocker.patch(
            "core.plugin.impl.base.httpx.stream",
            return_value=_StreamContext([b"", b"data: hello", "world"]),
        )
        result = list(client._stream_request("POST", "plugin/tenant/stream", data={"k": "v"}))
        assert result == ["hello", "world"]
        assert stream_mock.call_args.kwargs["data"] == {"k": "v"}

    def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker):
        """Transport-level failures are wrapped in a daemon-request ValueError."""
        client = BasePluginClient()
        mocker.patch.object(client, "_request", side_effect=RuntimeError("boom"))
        with pytest.raises(ValueError, match="Failed to request plugin daemon"):
            client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool)

    def test_request_with_plugin_daemon_response_applies_transformer(self, mocker):
        """The optional transformer sees the full decoded daemon envelope."""
        client = BasePluginClient()
        mocker.patch.object(client, "_request", return_value=_ResponseStub({"code": 0, "message": "", "data": True}))
        transformed = {}

        def transformer(payload):
            transformed.update(payload)
            return payload

        result = client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool, transformer=transformer)
        assert result is True
        assert transformed == {"code": 0, "message": "", "data": True}

    def test_request_with_plugin_daemon_response_stream_malformed_json_error(self, mocker):
        """A line carrying an 'error' key raises with that error text."""
        client = BasePluginClient()
        mocker.patch.object(client, "_stream_request", return_value=iter(['{"error":"bad-line"}']))
        with pytest.raises(ValueError, match="bad-line"):
            list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))

    def test_request_with_plugin_daemon_response_stream_plugin_daemon_inner_error(self, mocker):
        """Code -500 with a non-JSON message maps to PluginDaemonInnerError."""
        client = BasePluginClient()
        mocker.patch.object(
            client, "_stream_request", return_value=iter(['{"code":-500,"message":"not-json","data":null}'])
        )
        with pytest.raises(PluginDaemonInnerError) as exc_info:
            list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))
        assert exc_info.value.message == "not-json"

    def test_request_with_plugin_daemon_response_stream_plugin_daemon_error(self, mocker):
        """Any other non-zero code raises a ValueError with message and code."""
        client = BasePluginClient()
        mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":-1,"message":"err","data":null}']))
        with pytest.raises(ValueError, match="plugin daemon: err, code: -1"):
            list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))

    def test_request_with_plugin_daemon_response_stream_empty_data_error(self, mocker):
        """A success code with null data is rejected as empty data."""
        client = BasePluginClient()
        mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":0,"message":"","data":null}']))
        with pytest.raises(ValueError, match="got empty data"):
            list(client._request_with_plugin_daemon_response_stream("GET", "p", bool))

    @pytest.mark.parametrize(
        ("error_type", "expected"),
        [
            (EndpointSetupFailedError.__name__, EndpointSetupFailedError),
            (TriggerProviderCredentialValidationError.__name__, TriggerProviderCredentialValidationError),
            (TriggerPluginInvokeError.__name__, TriggerPluginInvokeError),
            (TriggerInvokeError.__name__, TriggerInvokeError),
            (EventIgnoreError.__name__, EventIgnoreError),
        ],
    )
    def test_handle_plugin_daemon_error_trigger_branches(self, error_type, expected):
        """Each declared error_type in a PluginInvokeError maps to its class."""
        client = BasePluginClient()
        message = json.dumps({"error_type": error_type, "message": "m"})
        with pytest.raises(expected):
            client._handle_plugin_daemon_error("PluginInvokeError", message)

View File

@@ -0,0 +1,234 @@
from types import SimpleNamespace
from core.datasource.entities.datasource_entities import (
GetOnlineDocumentPageContentRequest,
OnlineDriveBrowseFilesRequest,
OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
def _datasource_provider(name: str = "provider") -> SimpleNamespace:
return SimpleNamespace(
plugin_id="org/plugin",
declaration=SimpleNamespace(
identity=SimpleNamespace(name=name),
datasources=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
),
)
class TestPluginDatasourceManager:
    """Unit tests for PluginDatasourceManager provider fetching and streaming ops."""

    def test_fetch_datasource_providers(self, mocker):
        """Listing resolves schema $refs, qualifies names, and prepends the local file provider."""
        manager = PluginDatasourceManager()
        provider = _datasource_provider("remote")
        repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider")
        mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})

        def fake_request(method, path, type_, **kwargs):
            # The transformer must resolve $refs inside each datasource output_schema.
            transformer = kwargs["transformer"]
            payload = {
                "data": [
                    {
                        "declaration": {
                            "identity": {"name": "remote"},
                            "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}],
                        }
                    }
                ]
            }
            transformed = transformer(payload)
            assert transformed["data"][0]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True}
            return [provider]

        request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = manager.fetch_datasource_providers("tenant-1")
        assert request_mock.call_count == 1
        # Two entries: the built-in local-file provider plus the remote one.
        assert len(result) == 2
        assert result[0].plugin_id == "langgenius/file"
        assert result[1].declaration.identity.name == "org/plugin/remote"
        assert result[1].declaration.datasources[0].identity.provider == "org/plugin/remote"
        repack.assert_called_once_with(tenant_id="tenant-1", provider=provider)

    def test_fetch_installed_datasource_providers(self, mocker):
        """Installed-provider listing qualifies names but omits the local file provider."""
        manager = PluginDatasourceManager()
        provider = _datasource_provider("remote")
        repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider")
        mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})

        def fake_request(method, path, type_, **kwargs):
            transformer = kwargs["transformer"]
            payload = {
                "data": [
                    {
                        "declaration": {
                            "identity": {"name": "remote"},
                            "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}],
                        }
                    }
                ]
            }
            transformer(payload)
            return [provider]

        request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = manager.fetch_installed_datasource_providers("tenant-1")
        assert request_mock.call_count == 1
        assert len(result) == 1
        assert result[0].declaration.identity.name == "org/plugin/remote"
        assert result[0].declaration.datasources[0].identity.provider == "org/plugin/remote"
        repack.assert_called_once_with(tenant_id="tenant-1", provider=provider)

    def test_fetch_datasource_provider_local_and_remote(self, mocker):
        """The local file provider short-circuits; remote ids go to the daemon."""
        manager = PluginDatasourceManager()
        # Local built-in provider is served without any daemon request.
        local = manager.fetch_datasource_provider("tenant-1", "langgenius/file/file")
        assert local.plugin_id == "langgenius/file"
        remote = _datasource_provider("provider")
        mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True})

        def fake_request(method, path, type_, **kwargs):
            transformer = kwargs["transformer"]
            payload = {
                "data": {
                    "declaration": {
                        "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]
                    }
                }
            }
            transformed = transformer(payload)
            assert transformed["data"]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True}
            return remote

        request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = manager.fetch_datasource_provider("tenant-1", "org/plugin/provider")
        assert request_mock.call_count == 1
        assert result.declaration.identity.name == "org/plugin/provider"
        assert result.declaration.datasources[0].identity.provider == "org/plugin/provider"

    def test_get_website_crawl_streaming(self, mocker):
        """get_website_crawl streams results straight from the daemon stream."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter(["crawl"])
        assert list(
            manager.get_website_crawl(
                "tenant-1",
                "user-1",
                "org/plugin/provider",
                "crawl",
                {"k": "v"},
                {"url": "https://example.com"},
                "website",
            )
        ) == ["crawl"]
        assert stream_mock.call_count == 1

    def test_get_online_document_pages_streaming(self, mocker):
        """get_online_document_pages streams page listings from the daemon."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter(["pages"])
        assert list(
            manager.get_online_document_pages(
                "tenant-1",
                "user-1",
                "org/plugin/provider",
                "docs",
                {"k": "v"},
                {"workspace": "w1"},
                "online_document",
            )
        ) == ["pages"]
        assert stream_mock.call_count == 1

    def test_get_online_document_page_content_streaming(self, mocker):
        """get_online_document_page_content streams a single page's content."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter(["content"])
        assert list(
            manager.get_online_document_page_content(
                "tenant-1",
                "user-1",
                "org/plugin/provider",
                "docs",
                {"k": "v"},
                GetOnlineDocumentPageContentRequest(workspace_id="w", page_id="p", type="doc"),
                "online_document",
            )
        ) == ["content"]
        assert stream_mock.call_count == 1

    def test_online_drive_browse_files_streaming(self, mocker):
        """online_drive_browse_files streams directory listings."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter(["browse"])
        assert list(
            manager.online_drive_browse_files(
                "tenant-1",
                "user-1",
                "org/plugin/provider",
                "drive",
                {"k": "v"},
                OnlineDriveBrowseFilesRequest(prefix="/"),
                "online_drive",
            )
        ) == ["browse"]
        assert stream_mock.call_count == 1

    def test_online_drive_download_file_streaming(self, mocker):
        """online_drive_download_file streams file chunks."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter(["download"])
        assert list(
            manager.online_drive_download_file(
                "tenant-1",
                "user-1",
                "org/plugin/provider",
                "drive",
                {"k": "v"},
                OnlineDriveDownloadFileRequest(id="file-1"),
                "online_drive",
            )
        ) == ["download"]
        assert stream_mock.call_count == 1

    def test_validate_provider_credentials_returns_true_when_stream_yields_result(self, mocker):
        """Validation returns the first streamed result flag."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter([SimpleNamespace(result=True)])
        assert manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is True

    def test_validate_provider_credentials_returns_false_when_stream_empty(self, mocker):
        """An empty validation stream means credentials are not validated."""
        manager = PluginDatasourceManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter([])
        assert (
            manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is False
        )

    def test_local_file_provider_template(self):
        """The built-in local-file provider template has the expected identity."""
        manager = PluginDatasourceManager()
        payload = manager._get_local_file_datasource_provider()
        assert payload["plugin_id"] == "langgenius/file"
        assert payload["provider"] == "file"
        assert payload["declaration"]["provider_type"] == "local_file"

View File

@@ -0,0 +1,21 @@
from types import SimpleNamespace
from core.plugin.impl.debugging import PluginDebuggingClient
class TestPluginDebuggingClient:
    """Unit tests for PluginDebuggingClient."""

    def test_get_debugging_key(self, mocker):
        """get_debugging_key POSTs to the tenant's debugging endpoint and unwraps .key."""
        client = PluginDebuggingClient()
        request_mock = mocker.patch.object(
            client,
            "_request_with_plugin_daemon_response",
            return_value=SimpleNamespace(key="debug-key"),
        )
        result = client.get_debugging_key("tenant-1")
        assert result == "debug-key"
        request_mock.assert_called_once()
        args = request_mock.call_args.args
        assert args[0] == "POST"
        assert args[1] == "plugin/tenant-1/debugging/key"

View File

@@ -0,0 +1,71 @@
import pytest
from core.plugin.impl.endpoint import PluginEndpointClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class TestPluginEndpointClientImpl:
    """Unit tests for PluginEndpointClient CRUD and enable/disable endpoints."""

    def test_create_endpoint(self, mocker):
        """create_endpoint POSTs the setup payload with the plugin identifier."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
        result = client.create_endpoint("tenant-1", "user-1", "org/plugin:1", "endpoint-a", {"k": "v"})
        assert result is True
        assert request_mock.call_count == 1
        args = request_mock.call_args.args
        kwargs = request_mock.call_args.kwargs
        assert args[:3] == ("POST", "plugin/tenant-1/endpoint/setup", bool)
        assert kwargs["data"]["plugin_unique_identifier"] == "org/plugin:1"

    def test_list_endpoints(self, mocker):
        """list_endpoints forwards pagination parameters."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"])
        result = client.list_endpoints("tenant-1", "user-1", 2, 20)
        assert result == ["endpoint"]
        assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list"
        assert request_mock.call_args.kwargs["params"] == {"page": 2, "page_size": 20}

    def test_list_endpoints_for_single_plugin(self, mocker):
        """Per-plugin listing adds the plugin_id to the query params."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"])
        result = client.list_endpoints_for_single_plugin("tenant-1", "user-1", "org/plugin", 1, 10)
        assert result == ["endpoint"]
        assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list/plugin"
        assert request_mock.call_args.kwargs["params"] == {"plugin_id": "org/plugin", "page": 1, "page_size": 10}

    def test_update_endpoint(self, mocker):
        """update_endpoint POSTs to the update route and returns the daemon flag."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
        result = client.update_endpoint("tenant-1", "user-1", "endpoint-1", "renamed", {"x": 1})
        assert result is True
        assert request_mock.call_args.args[:3] == ("POST", "plugin/tenant-1/endpoint/update", bool)

    def test_enable_and_disable_endpoint(self, mocker):
        """enable/disable hit their respective routes in order."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True)
        assert client.enable_endpoint("tenant-1", "user-1", "endpoint-1") is True
        assert client.disable_endpoint("tenant-1", "user-1", "endpoint-1") is True
        calls = request_mock.call_args_list
        assert calls[0].args[1] == "plugin/tenant-1/endpoint/enable"
        assert calls[1].args[1] == "plugin/tenant-1/endpoint/disable"

    def test_delete_endpoint_idempotent_and_re_raise(self, mocker):
        """'record not found' is treated as success; other daemon errors re-raise."""
        client = PluginEndpointClient()
        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response")
        request_mock.side_effect = PluginDaemonInternalServerError("record not found")
        assert client.delete_endpoint("tenant-1", "user-1", "endpoint-1") is True
        request_mock.side_effect = PluginDaemonInternalServerError("permission denied")
        with pytest.raises(PluginDaemonInternalServerError) as exc_info:
            client.delete_endpoint("tenant-1", "user-1", "endpoint-1")
        assert "permission denied" in exc_info.value.description

View File

@@ -0,0 +1,41 @@
import json
from core.plugin.impl import exc as exc_module
from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError
class TestPluginImplExceptions:
    """Unit tests for the plugin-impl exception hierarchy's formatting helpers."""

    def test_plugin_daemon_error_str_contains_request_id(self, mocker):
        """str() embeds the current request id and the class name."""
        mocker.patch("core.plugin.impl.exc.get_request_id", return_value="req-123")
        error = PluginDaemonError("bad")
        assert str(error) == "req_id: req-123 PluginDaemonError: bad"

    def test_plugin_invoke_error_with_json_payload(self):
        """A JSON description yields structured type/message and a friendly string."""
        err = PluginInvokeError(json.dumps({"error_type": "RateLimit", "message": "too many"}))
        assert err.get_error_type() == "RateLimit"
        assert err.get_error_message() == "too many"
        friendly = err.to_user_friendly_error("test-plugin")
        assert "test-plugin" in friendly
        assert "RateLimit" in friendly
        assert "too many" in friendly

    def test_plugin_invoke_error_invalid_json_and_fallback(self, mocker):
        """Non-JSON descriptions degrade to 'unknown'; parse failures fall back to raw text."""
        err = PluginInvokeError("plain text")
        assert err._get_error_object() == {}
        assert err.get_error_type() == "unknown"
        assert err.get_error_message() == "unknown"
        # If even _get_error_object blows up, the raw description is returned.
        mocker.patch.object(PluginInvokeError, "_get_error_object", side_effect=RuntimeError("boom"))
        err2 = PluginInvokeError("plain text")
        assert err2.get_error_message() == "plain text"

    def test_plugin_invoke_error_get_error_object_handles_adapter_exception(self, mocker):
        """A failing TypeAdapter validation results in an empty error object."""
        adapter = mocker.patch.object(exc_module, "TypeAdapter")
        adapter.return_value.validate_json.side_effect = RuntimeError("invalid")
        err = PluginInvokeError("not-json")
        assert err._get_error_object() == {}

View File

@@ -0,0 +1,490 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
from core.plugin.impl.model import PluginModelClient
class TestPluginModelClient:
def test_fetch_model_providers(self, mocker):
client = PluginModelClient()
request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["provider-a"])
result = client.fetch_model_providers("tenant-1")
assert result == ["provider-a"]
assert request_mock.call_args.args[:2] == (
"GET",
"plugin/tenant-1/management/models",
)
assert request_mock.call_args.kwargs["params"] == {"page": 1, "page_size": 256}
def test_get_model_schema(self, mocker):
client = PluginModelClient()
schema = SimpleNamespace(name="schema")
stream_mock = mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(model_schema=schema)]),
)
result = client.get_model_schema(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model_type="llm",
model="gpt-test",
credentials={"api_key": "key"},
)
assert result is schema
assert stream_mock.call_args.args[:2] == ("POST", "plugin/tenant-1/dispatch/model/schema")
def test_get_model_schema_empty_stream_returns_none(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
result = client.get_model_schema("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {})
assert result is None
def test_validate_provider_credentials(self, mocker):
client = PluginModelClient()
stream_mock = mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(result=True, credentials={"api_key": "new"})]),
)
credentials = {"api_key": "old"}
result = client.validate_provider_credentials(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
credentials=credentials,
)
assert result is True
assert credentials["api_key"] == "new"
assert stream_mock.call_args.args[:2] == (
"POST",
"plugin/tenant-1/dispatch/model/validate_provider_credentials",
)
def test_validate_provider_credentials_without_dict_update(self, mocker):
client = PluginModelClient()
mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(result=False, credentials="not-a-dict")]),
)
credentials = {"api_key": "same"}
result = client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", credentials)
assert result is False
assert credentials == {"api_key": "same"}
def test_validate_provider_credentials_empty_returns_false(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", {}) is False
def test_validate_model_credentials(self, mocker):
client = PluginModelClient()
stream_mock = mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(result=True, credentials={"token": "rotated"})]),
)
credentials = {"token": "old"}
result = client.validate_model_credentials(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model_type="llm",
model="gpt-test",
credentials=credentials,
)
assert result is True
assert credentials["token"] == "rotated"
assert stream_mock.call_args.args[:2] == (
"POST",
"plugin/tenant-1/dispatch/model/validate_model_credentials",
)
def test_validate_model_credentials_empty_returns_false(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
assert (
client.validate_model_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {})
is False
)
def test_invoke_llm(self, mocker):
client = PluginModelClient()
stream_mock = mocker.patch.object(
client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
)
result = list(
client.invoke_llm(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model="gpt-test",
credentials={"api_key": "key"},
prompt_messages=[],
model_parameters={"temperature": 0.1},
tools=[],
stop=["STOP"],
stream=False,
)
)
assert result == ["chunk-1"]
call_kwargs = stream_mock.call_args.kwargs
assert call_kwargs["path"] == "plugin/tenant-1/dispatch/llm/invoke"
assert call_kwargs["data"]["data"]["stream"] is False
assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1}
def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker):
client = PluginModelClient()
def _boom():
raise PluginDaemonInnerError(code=-500, message="invoke failed")
yield # pragma: no cover
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom())
with pytest.raises(ValueError, match="invoke failed-500"):
list(
client.invoke_llm(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model="gpt-test",
credentials={},
prompt_messages=[],
)
)
def test_get_llm_num_tokens(self, mocker):
client = PluginModelClient()
mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(num_tokens=42)]),
)
result = client.get_llm_num_tokens(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model_type="llm",
model="gpt-test",
credentials={},
prompt_messages=[],
tools=[],
)
assert result == 42
def test_get_llm_num_tokens_empty_returns_zero(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
assert (
client.get_llm_num_tokens("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}, [])
== 0
)
def test_invoke_text_embedding(self, mocker):
client = PluginModelClient()
embedding_result = SimpleNamespace(data=[[0.1, 0.2]])
mocker.patch.object(
client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result])
)
result = client.invoke_text_embedding(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model="embedding-a",
credentials={},
texts=["hello"],
input_type="search_document",
)
assert result is embedding_result
def test_invoke_text_embedding_empty_raises(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
with pytest.raises(ValueError, match="Failed to invoke text embedding"):
client.invoke_text_embedding(
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["hello"], "x"
)
def test_invoke_multimodal_embedding(self, mocker):
client = PluginModelClient()
embedding_result = SimpleNamespace(data=[[0.3, 0.4]])
mocker.patch.object(
client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result])
)
result = client.invoke_multimodal_embedding(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model="embedding-a",
credentials={},
documents=[{"type": "image", "value": "abc"}],
input_type="search_document",
)
assert result is embedding_result
def test_invoke_multimodal_embedding_empty_raises(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
with pytest.raises(ValueError, match="Failed to invoke file embedding"):
client.invoke_multimodal_embedding(
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, [{"type": "image"}], "x"
)
def test_get_text_embedding_num_tokens(self, mocker):
client = PluginModelClient()
mocker.patch.object(
client,
"_request_with_plugin_daemon_response_stream",
return_value=iter([SimpleNamespace(num_tokens=[1, 2, 3])]),
)
assert client.get_text_embedding_num_tokens(
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"]
) == [
1,
2,
3,
]
def test_get_text_embedding_num_tokens_empty_returns_list(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
assert (
client.get_text_embedding_num_tokens(
"tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"]
)
== []
)
def test_invoke_rerank(self, mocker):
client = PluginModelClient()
rerank_result = SimpleNamespace(scores=[0.9])
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result]))
result = client.invoke_rerank(
tenant_id="tenant-1",
user_id="user-1",
plugin_id="org/plugin:1",
provider="provider-a",
model="rerank-a",
credentials={},
query="q",
docs=["doc-1"],
score_threshold=0.2,
top_n=5,
)
assert result is rerank_result
def test_invoke_rerank_empty_raises(self, mocker):
client = PluginModelClient()
mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
with pytest.raises(ValueError, match="Failed to invoke rerank"):
client.invoke_rerank("tenant-1", "user-1", "org/plugin:1", "provider-a", "rerank-a", {}, "q", ["doc-1"])
def test_invoke_multimodal_rerank(self, mocker):
    """The first multimodal rerank result from the stream is returned unchanged (same object)."""
    client = PluginModelClient()
    rerank_result = SimpleNamespace(scores=[0.8])
    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result]))
    result = client.invoke_multimodal_rerank(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="rerank-a",
        credentials={},
        query={"type": "text", "value": "q"},
        docs=[{"type": "image", "value": "doc"}],
        score_threshold=0.1,
        top_n=3,
    )
    # Identity assertion: the daemon payload passes through without re-wrapping.
    assert result is rerank_result
def test_invoke_multimodal_rerank_empty_raises(self, mocker):
    """No multimodal rerank payload from the daemon stream -> ValueError."""
    client = PluginModelClient()
    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter(()))
    with pytest.raises(ValueError, match="Failed to invoke multimodal rerank"):
        client.invoke_multimodal_rerank(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin:1",
            provider="provider-a",
            model="rerank-a",
            credentials={},
            query={"type": "text"},
            docs=[{"type": "image"}],
        )
def test_invoke_tts(self, mocker):
    """Each streamed TTS result is hex-decoded into raw audio bytes."""
    client = PluginModelClient()
    # "68656c6c6f" is hex for b"hello"; "21" is hex for b"!".
    mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter([SimpleNamespace(result="68656c6c6f"), SimpleNamespace(result="21")]),
    )
    result = list(
        client.invoke_tts(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin:1",
            provider="provider-a",
            model="tts-a",
            credentials={},
            content_text="hello",
            voice="alloy",
        )
    )
    assert result == [b"hello", b"!"]
def test_invoke_tts_wraps_plugin_daemon_inner_error(self, mocker):
    """A PluginDaemonInnerError raised while consuming the stream is wrapped in a ValueError."""

    def _boom():
        # Generator that raises on first next(); the trailing `yield` only makes
        # this function a generator and is never reached.
        raise PluginDaemonInnerError(code=-400, message="tts error")
        yield  # pragma: no cover

    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom())
    # The wrapped message concatenates the daemon message and code ("tts error" + "-400").
    with pytest.raises(ValueError, match="tts error-400"):
        list(client.invoke_tts("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}, "hello", "alloy"))
def test_get_tts_model_voices(self, mocker):
    """Voice entries from the stream are flattened into name/value dicts."""
    client = PluginModelClient()
    voices = [
        SimpleNamespace(name="Alloy", value="alloy"),
        SimpleNamespace(name="Echo", value="echo"),
    ]
    mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter([SimpleNamespace(voices=voices)]),
    )
    result = client.get_tts_model_voices(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="tts-a",
        credentials={},
        language="en",
    )
    expected = [{"name": "Alloy", "value": "alloy"}, {"name": "Echo", "value": "echo"}]
    assert result == expected
def test_get_tts_model_voices_empty_returns_list(self, mocker):
    """An exhausted stream yields an empty voice list."""
    client = PluginModelClient()
    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter(()))
    voices = client.get_tts_model_voices("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {})
    assert voices == []
def test_invoke_speech_to_text(self, mocker):
    """The audio file is hex-encoded into the request payload and the first result is returned."""
    client = PluginModelClient()
    stream_mock = mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter([SimpleNamespace(result="transcribed text")]),
    )
    result = client.invoke_speech_to_text(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="stt-a",
        credentials={},
        file=io.BytesIO(b"abc"),
    )
    assert result == "transcribed text"
    # "616263" is the hex encoding of b"abc" — the file content goes out hex-encoded.
    assert stream_mock.call_args.kwargs["data"]["data"]["file"] == "616263"
def test_invoke_speech_to_text_empty_raises(self, mocker):
    """No transcription payload from the daemon stream -> ValueError."""
    client = PluginModelClient()
    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter(()))
    audio = io.BytesIO(b"abc")
    with pytest.raises(ValueError, match="Failed to invoke speech to text"):
        client.invoke_speech_to_text("tenant-1", "user-1", "org/plugin:1", "provider-a", "stt-a", {}, audio)
def test_invoke_moderation(self, mocker):
    """The first moderation verdict from the stream is returned and the daemon path is correct."""
    client = PluginModelClient()
    stream_mock = mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter([SimpleNamespace(result=True)]),
    )
    result = client.invoke_moderation(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="moderation-a",
        credentials={},
        text="safe text",
    )
    assert result is True
    # Pin the dispatch URL so route changes are caught by this test.
    assert stream_mock.call_args.kwargs["path"] == "plugin/tenant-1/dispatch/moderation/invoke"
def test_invoke_moderation_empty_raises(self, mocker):
    """No moderation verdict from the daemon stream -> ValueError."""
    client = PluginModelClient()
    mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter(()))
    with pytest.raises(ValueError, match="Failed to invoke moderation"):
        client.invoke_moderation("tenant-1", "user-1", "org/plugin:1", "provider-a", "moderation-a", {}, "unsafe")

View File

@ -0,0 +1,147 @@
from io import BytesIO
from types import SimpleNamespace
import pytest
from werkzeug import Request
from core.plugin.impl.oauth import OAuthHandler
def _build_request(body: bytes = b"payload") -> Request:
    """Build a minimal POST werkzeug Request from a hand-rolled WSGI environ.

    The request targets /oauth/callback?code=123 and carries one custom header
    (X-Test) so serialization tests can assert on it.
    """
    environ = {
        "REQUEST_METHOD": "POST",
        "PATH_INFO": "/oauth/callback",
        "QUERY_STRING": "code=123",
        "SERVER_NAME": "localhost",
        "SERVER_PORT": "80",
        "wsgi.input": BytesIO(body),
        "wsgi.url_scheme": "http",
        "CONTENT_LENGTH": str(len(body)),
        "HTTP_HOST": "localhost",
        "SERVER_PROTOCOL": "HTTP/1.1",
        "HTTP_X_TEST": "yes",
    }
    return Request(environ)
class TestOAuthHandler:
    """Unit tests for OAuthHandler's plugin-daemon-backed OAuth flows."""

    def test_get_authorization_url(self, mocker):
        """The first stream item's authorization_url is returned to the caller."""
        handler = OAuthHandler()
        stream_mock = mocker.patch.object(
            handler,
            "_request_with_plugin_daemon_response_stream",
            return_value=iter([SimpleNamespace(authorization_url="https://auth.example.com")]),
        )
        response = handler.get_authorization_url(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin",
            provider="provider",
            redirect_uri="https://dify.example.com/callback",
            system_credentials={"client_id": "id"},
        )
        assert response.authorization_url == "https://auth.example.com"
        assert stream_mock.call_count == 1

    def test_get_authorization_url_no_response_raises(self, mocker):
        """An empty daemon stream raises a descriptive ValueError."""
        handler = OAuthHandler()
        mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
        with pytest.raises(ValueError, match="Error getting authorization URL"):
            handler.get_authorization_url(
                tenant_id="tenant-1",
                user_id="user-1",
                plugin_id="org/plugin",
                provider="provider",
                redirect_uri="https://dify.example.com/callback",
                system_credentials={},
            )

    def test_get_credentials(self, mocker):
        """Credentials are returned and the serialized HTTP request is put into the payload."""
        handler = OAuthHandler()
        captured_data = {}

        def fake_stream(*args, **kwargs):
            # Capture the outgoing request payload so the serialized HTTP request
            # ("raw_http_request") can be asserted on afterwards.
            captured_data.update(kwargs["data"])
            return iter([SimpleNamespace(credentials={"token": "abc"}, metadata={}, expires_at=1)])

        stream_mock = mocker.patch.object(
            handler, "_request_with_plugin_daemon_response_stream", side_effect=fake_stream
        )
        response = handler.get_credentials(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin",
            provider="provider",
            redirect_uri="https://dify.example.com/callback",
            system_credentials={"client_id": "id"},
            request=_build_request(),
        )
        assert response.credentials == {"token": "abc"}
        assert "raw_http_request" in captured_data["data"]
        assert stream_mock.call_count == 1

    def test_get_credentials_no_response_raises(self, mocker):
        """An empty daemon stream raises a descriptive ValueError."""
        handler = OAuthHandler()
        mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
        with pytest.raises(ValueError, match="Error getting credentials"):
            handler.get_credentials(
                tenant_id="tenant-1",
                user_id="user-1",
                plugin_id="org/plugin",
                provider="provider",
                redirect_uri="https://dify.example.com/callback",
                system_credentials={},
                request=_build_request(),
            )

    def test_refresh_credentials(self, mocker):
        """Refreshed credentials from the first stream item are returned."""
        handler = OAuthHandler()
        stream_mock = mocker.patch.object(
            handler,
            "_request_with_plugin_daemon_response_stream",
            return_value=iter([SimpleNamespace(credentials={"token": "new"}, metadata={}, expires_at=1)]),
        )
        response = handler.refresh_credentials(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin",
            provider="provider",
            redirect_uri="https://dify.example.com/callback",
            system_credentials={"client_id": "id"},
            credentials={"refresh_token": "r"},
        )
        assert response.credentials == {"token": "new"}
        assert stream_mock.call_count == 1

    def test_refresh_credentials_no_response_raises(self, mocker):
        """An empty daemon stream raises a descriptive ValueError."""
        handler = OAuthHandler()
        mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
        with pytest.raises(ValueError, match="Error refreshing credentials"):
            handler.refresh_credentials(
                tenant_id="tenant-1",
                user_id="user-1",
                plugin_id="org/plugin",
                provider="provider",
                redirect_uri="https://dify.example.com/callback",
                system_credentials={},
                credentials={},
            )

    def test_convert_request_to_raw_data(self):
        """The werkzeug request is serialized into a raw HTTP/1.1 byte message."""
        handler = OAuthHandler()
        request = _build_request(b"body-data")
        raw = handler._convert_request_to_raw_data(request)
        # Request line, then headers (X-Test comes from HTTP_X_TEST), then body.
        assert raw.startswith(b"POST /oauth/callback?code=123 HTTP/1.1\r\n")
        assert b"X-Test: yes\r\n" in raw
        assert raw.endswith(b"body-data")

View File

@ -0,0 +1,121 @@
from types import SimpleNamespace
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.tool import PluginToolManager
def _tool_provider(name: str = "provider") -> SimpleNamespace:
return SimpleNamespace(
plugin_id="org/plugin",
declaration=SimpleNamespace(
identity=SimpleNamespace(name=name),
tools=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
),
)
class TestPluginToolManager:
    """Unit tests for PluginToolManager's daemon request plumbing."""

    def test_fetch_tool_providers(self, mocker):
        """Identities get the plugin-id prefix and output schemas are ref-resolved."""
        manager = PluginToolManager()
        provider = _tool_provider("remote")
        mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True})

        def fake_request(method, path, type_, **kwargs):
            # Run the transformer the manager passes to the request helper, so its
            # schema-resolution behavior is exercised inside the mocked call.
            transformer = kwargs["transformer"]
            payload = {
                "data": [
                    {
                        "declaration": {
                            "identity": {"name": "remote"},
                            "tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}],
                        }
                    }
                ]
            }
            transformed = transformer(payload)
            assert transformed["data"][0]["declaration"]["tools"][0]["output_schema"] == {"resolved": True}
            return [provider]

        request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = manager.fetch_tool_providers("tenant-1")
        assert request_mock.call_count == 1
        assert result[0].declaration.identity.name == "org/plugin/remote"
        assert result[0].declaration.tools[0].identity.provider == "org/plugin/remote"

    def test_fetch_tool_provider(self, mocker):
        """Single-provider fetch applies the same prefixing and schema resolution."""
        manager = PluginToolManager()
        provider = _tool_provider("provider")
        mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True})

        def fake_request(method, path, type_, **kwargs):
            transformer = kwargs["transformer"]
            payload = {
                "data": {
                    "declaration": {"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]}
                }
            }
            transformed = transformer(payload)
            assert transformed["data"]["declaration"]["tools"][0]["output_schema"] == {"resolved": True}
            return provider

        request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = manager.fetch_tool_provider("tenant-1", "org/plugin/provider")
        assert request_mock.call_count == 1
        assert result.declaration.identity.name == "org/plugin/provider"
        assert result.declaration.tools[0].identity.provider == "org/plugin/provider"

    def test_invoke_merges_chunks(self, mocker):
        """invoke() funnels streamed chunks through merge_blob_chunks and sets the plugin header."""
        manager = PluginToolManager()
        stream_mock = mocker.patch.object(
            manager, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk"])
        )
        merge_mock = mocker.patch("core.plugin.impl.tool.merge_blob_chunks", return_value=["merged"])
        result = manager.invoke(
            tenant_id="tenant-1",
            user_id="user-1",
            tool_provider="org/plugin/provider",
            tool_name="search",
            credentials={"api_key": "k"},
            credential_type=CredentialType.API_KEY,
            tool_parameters={"q": "python"},
            conversation_id="conv-1",
            app_id="app-1",
            message_id="msg-1",
        )
        assert result == ["merged"]
        assert merge_mock.call_count == 1
        # The plugin id (without the provider suffix) must ride in the X-Plugin-ID header.
        assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin"

    def test_validate_credentials_paths(self, mocker):
        """Provider/datasource validation returns the streamed result, or False when empty."""
        manager = PluginToolManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter([SimpleNamespace(result=True)])
        assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
        stream_mock.return_value = iter([])
        assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False
        stream_mock.return_value = iter([SimpleNamespace(result=True)])
        assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
        stream_mock.return_value = iter([])
        assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False

    def test_get_runtime_parameters_paths(self, mocker):
        """Runtime parameters come from the first stream item, or [] when the stream is empty."""
        manager = PluginToolManager()
        stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter([SimpleNamespace(parameters=[{"name": "p"}])])
        params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search")
        assert params == [{"name": "p"}]
        stream_mock.return_value = iter([])
        params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search")
        assert params == []

View File

@ -0,0 +1,226 @@
from io import BytesIO
from types import SimpleNamespace
import pytest
from werkzeug import Request
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.trigger import PluginTriggerClient
from core.trigger.entities.entities import Subscription
from models.provider_ids import TriggerProviderID
def _request() -> Request:
    """Build a minimal POST werkzeug Request (body b"payload") for trigger tests."""
    environ = {
        "REQUEST_METHOD": "POST",
        "PATH_INFO": "/events",
        "QUERY_STRING": "",
        "SERVER_NAME": "localhost",
        "SERVER_PORT": "80",
        "wsgi.input": BytesIO(b"payload"),
        "wsgi.url_scheme": "http",
        "CONTENT_LENGTH": "7",
        "HTTP_HOST": "localhost",
    }
    return Request(environ)
def _subscription() -> Subscription:
    """Build a Subscription fixture with fixed expiry, endpoint, parameters and properties."""
    return Subscription(expires_at=123, endpoint="https://example.com/hook", parameters={"a": 1}, properties={"p": 1})
def _trigger_provider(name: str = "provider") -> SimpleNamespace:
return SimpleNamespace(
plugin_id="org/plugin",
declaration=SimpleNamespace(
identity=SimpleNamespace(name=name),
events=[SimpleNamespace(identity=SimpleNamespace(provider=""))],
),
)
def _subscription_call_kwargs(method_name: str) -> dict:
    """Return call kwargs for a subscription operation under test.

    `subscribe` takes endpoint/parameters; `unsubscribe` and `refresh` take an
    existing Subscription instead.
    """
    if method_name == "subscribe":
        return {
            "tenant_id": "tenant-1",
            "user_id": "user-1",
            "provider": "org/plugin/provider",
            "credentials": {"token": "x"},
            "credential_type": CredentialType.API_KEY,
            "endpoint": "https://example.com/hook",
            "parameters": {"k": "v"},
        }
    return {
        "tenant_id": "tenant-1",
        "user_id": "user-1",
        "provider": "org/plugin/provider",
        "subscription": _subscription(),
        "credentials": {"token": "x"},
        "credential_type": CredentialType.API_KEY,
    }
class TestPluginTriggerClient:
    """Unit tests for PluginTriggerClient's daemon request plumbing."""

    def test_fetch_trigger_providers(self, mocker):
        """Event identities are prefixed with the plugin id by the response transformer."""
        client = PluginTriggerClient()
        provider = _trigger_provider("remote")

        def fake_request(*args, **kwargs):
            # Exercise the transformer the client hands to the request helper.
            transformer = kwargs["transformer"]
            payload = {
                "data": [
                    {
                        "plugin_id": "org/plugin",
                        "provider": "remote",
                        "declaration": {"events": [{"identity": {"provider": "old"}}]},
                    }
                ]
            }
            transformed = transformer(payload)
            assert transformed["data"][0]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/remote"
            return [provider]

        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = client.fetch_trigger_providers("tenant-1")
        assert request_mock.call_count == 1
        assert result[0].declaration.identity.name == "org/plugin/remote"
        assert result[0].declaration.events[0].identity.provider == "org/plugin/remote"

    def test_fetch_trigger_provider(self, mocker):
        """Single-provider fetch applies the same event-identity prefixing."""
        client = PluginTriggerClient()
        provider = _trigger_provider("provider")

        def fake_request(*args, **kwargs):
            transformer = kwargs["transformer"]
            payload = {"data": {"declaration": {"events": [{"identity": {"provider": "old"}}]}}}
            transformed = transformer(payload)
            assert transformed["data"]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/provider"
            return provider

        request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request)
        result = client.fetch_trigger_provider("tenant-1", TriggerProviderID("org/plugin/provider"))
        assert request_mock.call_count == 1
        assert result.declaration.identity.name == "org/plugin/provider"
        assert result.declaration.events[0].identity.provider == "org/plugin/provider"

    def test_invoke_trigger_event(self, mocker):
        """The first stream item (variables + cancelled flag) is returned."""
        client = PluginTriggerClient()
        stream_mock = mocker.patch.object(
            client,
            "_request_with_plugin_daemon_response_stream",
            return_value=iter([SimpleNamespace(variables={"ok": True}, cancelled=False)]),
        )
        result = client.invoke_trigger_event(
            tenant_id="tenant-1",
            user_id="user-1",
            provider="org/plugin/provider",
            event_name="created",
            credentials={"token": "x"},
            credential_type=CredentialType.API_KEY,
            request=_request(),
            parameters={"k": "v"},
            subscription=_subscription(),
            payload={"payload": 1},
        )
        assert result.variables == {"ok": True}
        assert stream_mock.call_count == 1

    def test_invoke_trigger_event_no_response_raises(self, mocker):
        """An empty daemon stream raises a descriptive ValueError."""
        client = PluginTriggerClient()
        mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
        with pytest.raises(ValueError, match="No response received from plugin daemon for invoke trigger"):
            client.invoke_trigger_event(
                tenant_id="tenant-1",
                user_id="user-1",
                provider="org/plugin/provider",
                event_name="created",
                credentials={"token": "x"},
                credential_type=CredentialType.API_KEY,
                request=_request(),
                parameters={"k": "v"},
                subscription=_subscription(),
                payload={"payload": 1},
            )

    def test_validate_provider_credentials(self, mocker):
        """Validation returns the streamed result; an empty stream raises."""
        client = PluginTriggerClient()
        stream_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response_stream")
        stream_mock.return_value = iter([SimpleNamespace(result=True)])
        assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True
        stream_mock.return_value = iter([])
        with pytest.raises(
            ValueError, match="No response received from plugin daemon for validate provider credentials"
        ):
            client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"})

    def test_dispatch_event(self, mocker):
        """Dispatch returns the first stream item; an empty stream raises."""
        client = PluginTriggerClient()
        stream_mock = mocker.patch.object(
            client,
            "_request_with_plugin_daemon_response_stream",
            return_value=iter([SimpleNamespace(user_id="u", events=["e"])]),
        )
        result = client.dispatch_event(
            tenant_id="tenant-1",
            provider="org/plugin/provider",
            subscription={"id": "sub"},
            request=_request(),
            credentials={"token": "x"},
            credential_type=CredentialType.API_KEY,
        )
        assert result.user_id == "u"
        assert stream_mock.call_count == 1
        stream_mock.return_value = iter([])
        with pytest.raises(ValueError, match="No response received from plugin daemon for dispatch event"):
            client.dispatch_event(
                tenant_id="tenant-1",
                provider="org/plugin/provider",
                subscription={"id": "sub"},
                request=_request(),
                credentials={"token": "x"},
                credential_type=CredentialType.API_KEY,
            )

    @pytest.mark.parametrize("method_name", ["subscribe", "unsubscribe", "refresh"])
    def test_subscription_operations_success(self, mocker, method_name):
        """subscribe/unsubscribe/refresh all return the first streamed subscription."""
        client = PluginTriggerClient()
        stream_mock = mocker.patch.object(
            client,
            "_request_with_plugin_daemon_response_stream",
            return_value=iter([SimpleNamespace(subscription={"id": "sub"})]),
        )
        method = getattr(client, method_name)
        result = method(**_subscription_call_kwargs(method_name))
        assert result.subscription == {"id": "sub"}
        assert stream_mock.call_count == 1

    @pytest.mark.parametrize(
        ("method_name", "expected"),
        [
            ("subscribe", "No response received from plugin daemon for subscribe"),
            ("unsubscribe", "No response received from plugin daemon for unsubscribe"),
            ("refresh", "No response received from plugin daemon for refresh"),
        ],
    )
    def test_subscription_operations_no_response(self, mocker, method_name, expected):
        """Each subscription operation raises its own message on an empty stream."""
        client = PluginTriggerClient()
        mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([]))
        method = getattr(client, method_name)
        with pytest.raises(ValueError, match=expected):
            method(**_subscription_call_kwargs(method_name))

View File

@ -1,72 +1,359 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from models.model import AppMode
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
workflow = MagicMock()
workflow.created_by = "owner-id"
class _Chunk(BaseModel):
value: int
app = MagicMock()
app.mode = AppMode.ADVANCED_CHAT
app.workflow = workflow
mocker.patch(
"core.plugin.backwards_invocation.app.db",
SimpleNamespace(engine=MagicMock()),
class TestBaseBackwardsInvocation:
    """Tests for BaseBackwardsInvocation.convert_to_event_stream."""

    def test_convert_to_event_stream_with_generator_and_error(self):
        """BaseModel/dict items become JSON chunks, bare strings are skipped, and a raised
        error becomes a final error chunk (4 yields -> 3 chunks)."""

        def _stream():
            # NOTE(review): _Chunk is defined earlier in the file; in this extract that
            # definition is interleaved with unrelated diff lines — presumably a pydantic
            # BaseModel with a `value: int` field. Confirm against the full file.
            yield _Chunk(value=1)
            yield {"x": 2}
            yield "ignored"
            raise RuntimeError("boom")

        chunks = list(BaseBackwardsInvocation.convert_to_event_stream(_stream()))
        assert len(chunks) == 3
        first = json.loads(chunks[0].decode())
        second = json.loads(chunks[1].decode())
        error = json.loads(chunks[2].decode())
        assert first["data"]["value"] == 1
        assert second["data"]["x"] == 2
        assert error["error"] == "boom"

    def test_convert_to_event_stream_with_non_generator(self):
        """A plain value is emitted as a single chunk with an empty error field."""
        chunks = list(BaseBackwardsInvocation.convert_to_event_stream({"ok": True}))
        payload = json.loads(chunks[0].decode())
        assert payload["data"] == {"ok": True}
        assert payload["error"] == ""
class TestPluginAppBackwardsInvocation:
def test_fetch_app_info_workflow_path(self, mocker):
    """Workflow apps source parameters from workflow features_dict/user_input_form."""
    workflow = MagicMock()
    workflow.features_dict = {"feature": "v"}
    workflow.user_input_form.return_value = [{"name": "foo"}]
    app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow)
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
    mapper = mocker.patch(
        "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
        return_value={"mapped": True},
    )
    result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1")
    assert result == {"data": {"mapped": True}}
    mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}])
def test_fetch_app_info_model_config_path(self, mocker):
    """Non-workflow apps source user_input_form from the app model config dict."""
    model_config = MagicMock()
    model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"}
    app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config)
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
    mocker.patch(
        "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
        return_value={"mapped": True},
    )
    result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1")
    assert result["data"] == {"mapped": True}
@pytest.mark.parametrize(
("mode", "route_method"),
[
(AppMode.CHAT, "invoke_chat_app"),
(AppMode.ADVANCED_CHAT, "invoke_chat_app"),
(AppMode.AGENT_CHAT, "invoke_chat_app"),
(AppMode.WORKFLOW, "invoke_workflow_app"),
(AppMode.COMPLETION, "invoke_completion_app"),
],
)
generator_spy = mocker.patch(
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
return_value={"result": "ok"},
def test_invoke_app_routes_by_mode(self, mocker, mode, route_method):
    """invoke_app dispatches to the per-mode handler named by the parametrization.

    NOTE(review): the @pytest.mark.parametrize decorator supplying (mode, route_method)
    appears above this method but is interleaved with unrelated diff lines in this
    extract — confirm it against the full file.
    """
    app = MagicMock(mode=mode)
    user = MagicMock()
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user)
    route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True})
    result = PluginAppBackwardsInvocation.invoke_app(
        app_id="app",
        user_id="user",
        tenant_id="tenant",
        conversation_id=None,
        query="hello",
        stream=False,
        inputs={"x": 1},
        files=[],
    )
    assert result == {"routed": True}
    assert route.call_count == 1
def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker):
    """An empty user_id falls back to EndUserService.get_or_create_end_user(app)."""
    app = MagicMock(mode=AppMode.WORKFLOW)
    end_user = MagicMock()
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
    get_or_create = mocker.patch(
        "core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user",
        return_value=end_user,
    )
    route = mocker.patch.object(PluginAppBackwardsInvocation, "invoke_workflow_app", return_value={"ok": True})
    result = PluginAppBackwardsInvocation.invoke_app(
        app_id="app",
        user_id="",
        tenant_id="tenant",
        conversation_id="",
        query=None,
        stream=True,
        inputs={},
        files=[],
    )
    assert result == {"ok": True}
    get_or_create.assert_called_once_with(app)
    # The resolved end user must be forwarded as the second positional argument.
    assert route.call_args.args[1] is end_user
def test_invoke_app_missing_query_for_chat_raises(self, mocker):
    """Chat-mode invocation with an empty query is rejected."""
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT))
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
    with pytest.raises(ValueError, match="missing query"):
        PluginAppBackwardsInvocation.invoke_app(
            app_id="app",
            user_id="user",
            tenant_id="tenant",
            conversation_id=None,
            query="",
            stream=False,
            inputs={},
            files=[],
        )
def test_invoke_app_unexpected_mode_raises(self, mocker):
    """An unknown app mode is rejected with 'unexpected app type'."""
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode="other"))
    mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
    with pytest.raises(ValueError, match="unexpected app type"):
        PluginAppBackwardsInvocation.invoke_app(
            app_id="app",
            user_id="user",
            tenant_id="tenant",
            conversation_id=None,
            query="q",
            stream=False,
            inputs={},
            files=[],
        )
@pytest.mark.parametrize(
("mode", "generator_path"),
[
(AppMode.AGENT_CHAT, "core.plugin.backwards_invocation.app.AgentChatAppGenerator.generate"),
(AppMode.CHAT, "core.plugin.backwards_invocation.app.ChatAppGenerator.generate"),
],
)
def test_invoke_chat_app_agent_and_chat(self, mocker, mode, generator_path):
app = MagicMock(mode=mode, workflow=None)
spy = mocker.patch(generator_path, return_value={"result": "ok"})
result = PluginAppBackwardsInvocation.invoke_chat_app(
app=app,
user=MagicMock(),
conversation_id="conv-1",
query="hello",
stream=False,
inputs={"k": "v"},
files=[],
)
result = PluginAppBackwardsInvocation.invoke_chat_app(
app=app,
user=MagicMock(),
conversation_id="conv-1",
query="hello",
stream=False,
inputs={"k": "v"},
files=[],
)
assert result == {"result": "ok"}
call_kwargs = generator_spy.call_args.kwargs
pause_state_config = call_kwargs.get("pause_state_config")
assert isinstance(pause_state_config, PauseStateLayerConfig)
assert pause_state_config.state_owner_user_id == "owner-id"
assert result == {"result": "ok"}
assert spy.call_count == 1
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(self, mocker):
workflow = MagicMock()
workflow.created_by = "owner-id"
def test_invoke_workflow_app_injects_pause_state_config(mocker):
workflow = MagicMock()
workflow.created_by = "owner-id"
app = MagicMock()
app.mode = AppMode.ADVANCED_CHAT
app.workflow = workflow
app = MagicMock()
app.mode = AppMode.WORKFLOW
app.workflow = workflow
mocker.patch(
"core.plugin.backwards_invocation.app.db",
SimpleNamespace(engine=MagicMock()),
)
generator_spy = mocker.patch(
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
return_value={"result": "ok"},
)
mocker.patch(
"core.plugin.backwards_invocation.app.db",
SimpleNamespace(engine=MagicMock()),
)
generator_spy = mocker.patch(
"core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
return_value={"result": "ok"},
)
result = PluginAppBackwardsInvocation.invoke_chat_app(
app=app,
user=MagicMock(),
conversation_id="conv-1",
query="hello",
stream=False,
inputs={"k": "v"},
files=[],
)
result = PluginAppBackwardsInvocation.invoke_workflow_app(
app=app,
user=MagicMock(),
stream=False,
inputs={"k": "v"},
files=[],
)
assert result == {"result": "ok"}
call_kwargs = generator_spy.call_args.kwargs
pause_state_config = call_kwargs.get("pause_state_config")
assert isinstance(pause_state_config, PauseStateLayerConfig)
assert pause_state_config.state_owner_user_id == "owner-id"
assert result == {"result": "ok"}
call_kwargs = generator_spy.call_args.kwargs
pause_state_config = call_kwargs.get("pause_state_config")
assert isinstance(pause_state_config, PauseStateLayerConfig)
assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self):
    """Advanced-chat apps without an attached workflow are rejected."""
    app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None)
    with pytest.raises(ValueError, match="unexpected app type"):
        PluginAppBackwardsInvocation.invoke_chat_app(
            app=app,
            user=MagicMock(),
            conversation_id="conv-1",
            query="hello",
            stream=False,
            inputs={},
            files=[],
        )
def test_invoke_chat_app_unexpected_mode_raises(self):
    """invoke_chat_app rejects app modes it does not know."""
    app = MagicMock(mode="invalid")
    with pytest.raises(ValueError, match="unexpected app type"):
        PluginAppBackwardsInvocation.invoke_chat_app(
            app=app,
            user=MagicMock(),
            conversation_id="conv-1",
            query="hello",
            stream=False,
            inputs={},
            files=[],
        )
def test_invoke_workflow_app_injects_pause_state_config(self, mocker):
    """invoke_workflow_app passes a PauseStateLayerConfig owned by the workflow author."""
    workflow = MagicMock()
    workflow.created_by = "owner-id"
    app = MagicMock()
    app.mode = AppMode.WORKFLOW
    app.workflow = workflow
    # The generator wiring touches db.engine; stub it out.
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(engine=MagicMock()),
    )
    generator_spy = mocker.patch(
        "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
        return_value={"result": "ok"},
    )
    result = PluginAppBackwardsInvocation.invoke_workflow_app(
        app=app,
        user=MagicMock(),
        stream=False,
        inputs={"k": "v"},
        files=[],
    )
    assert result == {"result": "ok"}
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert isinstance(pause_state_config, PauseStateLayerConfig)
    # The pause-state owner must be the workflow's creator.
    assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_workflow_app_without_workflow_raises(self):
    """Workflow apps without an attached workflow are rejected."""
    app = MagicMock(mode=AppMode.WORKFLOW, workflow=None)
    with pytest.raises(ValueError, match="unexpected app type"):
        PluginAppBackwardsInvocation.invoke_workflow_app(
            app=app,
            user=MagicMock(),
            stream=False,
            inputs={},
            files=[],
        )
def test_invoke_completion_app(self, mocker):
    """invoke_completion_app forwards to CompletionAppGenerator.generate exactly once."""
    spy = mocker.patch(
        "core.plugin.backwards_invocation.app.CompletionAppGenerator.generate", return_value={"ok": 1}
    )
    app = MagicMock(mode=AppMode.COMPLETION)
    result = PluginAppBackwardsInvocation.invoke_completion_app(app, MagicMock(), False, {"x": 1}, [])
    assert result == {"ok": 1}
    assert spy.call_count == 1
def test_get_user_returns_end_user(self, mocker):
    """_get_user returns the end user when the first session lookup hits."""
    session = MagicMock()
    # First session.scalar call resolves the end-user record.
    session.scalar.side_effect = [MagicMock(id="end-user")]
    session_ctx = MagicMock()
    session_ctx.__enter__.return_value = session
    session_ctx.__exit__.return_value = None
    mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
    mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
    user = PluginAppBackwardsInvocation._get_user("uid")
    assert user.id == "end-user"
def test_get_user_falls_back_to_account_user(self, mocker):
    """_get_user falls back to the account lookup when no end user is found."""
    session = MagicMock()
    # First lookup (end user) misses, second lookup (account) hits.
    session.scalar.side_effect = [None, MagicMock(id="account-user")]
    session_ctx = MagicMock()
    session_ctx.__enter__.return_value = session
    session_ctx.__exit__.return_value = None
    mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
    mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
    user = PluginAppBackwardsInvocation._get_user("uid")
    assert user.id == "account-user"
def test_get_user_raises_when_user_not_found(self, mocker):
    """_get_user raises when both the end-user and account lookups miss."""
    session = MagicMock()
    session.scalar.side_effect = [None, None]
    session_ctx = MagicMock()
    session_ctx.__enter__.return_value = session
    session_ctx.__exit__.return_value = None
    mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
    mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
    with pytest.raises(ValueError, match="user not found"):
        PluginAppBackwardsInvocation._get_user("uid")
def test_get_app_returns_app(self, mocker):
    """_get_app returns the row produced by the filtered query."""
    expected_app = MagicMock(id="app")
    chain = MagicMock()
    chain.where.return_value = chain  # support chained .where().where() calls
    chain.first.return_value = expected_app
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(session=MagicMock(query=MagicMock(return_value=chain))),
    )
    assert PluginAppBackwardsInvocation._get_app("app", "tenant") is expected_app
def test_get_app_raises_when_missing(self, mocker):
    """A query returning no row raises an 'app not found' ValueError."""
    chain = MagicMock()
    chain.where.return_value = chain
    chain.first.return_value = None
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(session=MagicMock(query=MagicMock(return_value=chain))),
    )
    with pytest.raises(ValueError, match="app not found"):
        PluginAppBackwardsInvocation._get_app("app", "tenant")
def test_get_app_raises_when_query_fails(self, mocker):
    """Database failures are surfaced as the same 'app not found' ValueError."""
    failing_session = MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))
    mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(session=failing_session))
    with pytest.raises(ValueError, match="app not found"):
        PluginAppBackwardsInvocation._get_app("app", "tenant")

View File

@ -0,0 +1,347 @@
import binascii
import datetime
from enum import StrEnum
import pytest
from flask import Response
from pydantic import ValidationError
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeSpeech2Text,
TriggerDispatchResponse,
TriggerInvokeEventResponse,
)
from core.plugin.utils.http_parser import serialize_response
from core.tools.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
class TestEndpointEntity:
    """URL handling of EndpointEntityWithInstance during model validation."""

    @staticmethod
    def _endpoint_payload(**overrides):
        # Fully-populated base payload; individual tests override only what matters.
        timestamp = datetime.datetime.now(datetime.UTC)
        payload = {
            "id": "ep-1",
            "created_at": timestamp,
            "updated_at": timestamp,
            "settings": {},
            "tenant_id": "tenant",
            "plugin_id": "org/plugin",
            "expired_at": timestamp,
            "name": "my-endpoint",
            "enabled": True,
            "hook_id": "hook-123",
        }
        payload.update(overrides)
        return payload

    def test_endpoint_entity_with_instance_renders_url(self, mocker):
        """With no explicit URL, the entity renders one from the configured template."""
        mocker.patch("core.plugin.entities.endpoint.dify_config.ENDPOINT_URL_TEMPLATE", "https://dify.test/{hook_id}")
        entity = EndpointEntityWithInstance.model_validate(self._endpoint_payload())
        assert entity.url == "https://dify.test/hook-123"

    def test_endpoint_entity_with_instance_keeps_existing_url(self):
        """A URL already present in the payload is preserved untouched."""
        entity = EndpointEntityWithInstance.model_validate(
            self._endpoint_payload(url="https://preset.test/hook-123")
        )
        assert entity.url == "https://preset.test/hook-123"
class TestMarketplaceEntities:
    """Marketplace plugin declaration / snapshot entity behaviour."""

    def test_marketplace_declaration_strips_empty_optional_fields(self):
        """Empty dicts for endpoint/model/tool normalize to None during validation."""
        raw_declaration = {
            "name": "plugin",
            "org": "org",
            "plugin_id": "org/plugin",
            "icon": "icon.png",
            "label": {"en_US": "Plugin"},
            "brief": {"en_US": "Brief"},
            "resource": {"memory": 256},
            "endpoint": {},
            "model": {},
            "tool": {},
            "latest_version": "1.0.0",
            "latest_package_identifier": "org/plugin@1.0.0",
            "status": "active",
            "deprecated_reason": "",
            "alternative_plugin_id": "",
        }
        declaration = MarketplacePluginDeclaration.model_validate(raw_declaration)
        assert declaration.endpoint is None
        assert declaration.model is None
        assert declaration.tool is None

    def test_marketplace_snapshot_computed_plugin_id(self):
        """plugin_id is computed from the org and name fields."""
        snapshot = MarketplacePluginSnapshot(
            org="langgenius",
            name="search",
            latest_version="1.0.0",
            latest_package_identifier="langgenius/search@1.0.0",
            latest_package_url="https://example.com/pkg",
        )
        assert snapshot.plugin_id == "langgenius/search"
class TestPluginParameterEntities:
    """Covers PluginParameter option handling plus the cast/init helper functions."""

    def _label(self) -> I18nObject:
        # Minimal i18n label reused by every fixture in this class.
        return I18nObject(en_US="label")

    def test_parameter_option_value_casts_to_string(self):
        """Non-string option values are coerced to str during validation."""
        option = PluginParameterOption(value=123, label=self._label())
        assert option.value == "123"

    def test_plugin_parameter_options_non_list_defaults_to_empty(self):
        """A non-list options payload degrades to an empty list instead of failing."""
        parameter = PluginParameter(name="p", label=self._label(), options="invalid")  # type: ignore[arg-type]
        assert parameter.options == []

    @pytest.mark.parametrize(
        ("parameter_type", "expected"),
        [
            (PluginParameterType.SECRET_INPUT, "string"),
            (PluginParameterType.SELECT, "string"),
            (PluginParameterType.CHECKBOX, "string"),
            (PluginParameterType.NUMBER, PluginParameterType.NUMBER.value),
        ],
    )
    def test_as_normal_type(self, parameter_type, expected):
        """Secret/select/checkbox map to 'string'; other types keep their own value."""
        assert as_normal_type(parameter_type) == expected

    @pytest.mark.parametrize(
        ("value", "expected"),
        [(None, ""), (1, "1"), ("abc", "abc")],
    )
    def test_cast_parameter_value_string_like(self, value, expected):
        """STRING casting: None becomes the empty string, others are stringified."""
        assert cast_parameter_value(PluginParameterType.STRING, value) == expected

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            (None, False),
            ("true", True),
            ("yes", True),
            ("1", True),
            ("false", False),
            ("0", False),
            ("random", True),  # unrecognized non-empty strings are treated as truthy
            (1, True),
            (0, False),
        ],
    )
    def test_cast_parameter_value_boolean(self, value, expected):
        """BOOLEAN casting recognizes common textual spellings of true/false."""
        assert cast_parameter_value(PluginParameterType.BOOLEAN, value) is expected

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            (1, 1),
            (1.5, 1.5),
            ("2", 2),
            ("2.5", 2.5),
        ],
    )
    def test_cast_parameter_value_number(self, value, expected):
        """NUMBER casting parses numeric strings into int or float as appropriate."""
        assert cast_parameter_value(PluginParameterType.NUMBER, value) == expected

    def test_cast_parameter_value_file_and_files(self):
        """FILES wraps a scalar in a list; FILE unwraps a one-element list and
        rejects lists with more than one entry."""
        assert cast_parameter_value(PluginParameterType.FILES, "f1") == ["f1"]
        assert cast_parameter_value(PluginParameterType.SYSTEM_FILES, ["f1", "f2"]) == ["f1", "f2"]
        assert cast_parameter_value(PluginParameterType.FILE, ["one"]) == "one"
        assert cast_parameter_value(PluginParameterType.FILE, "one") == "one"
        with pytest.raises(ValueError, match="only accepts one file"):
            cast_parameter_value(PluginParameterType.FILE, ["a", "b"])

    @pytest.mark.parametrize(
        ("parameter_type", "value", "expected"),
        [
            (PluginParameterType.MODEL_SELECTOR, {"m": "gpt"}, {"m": "gpt"}),
            (PluginParameterType.APP_SELECTOR, {"app": "a"}, {"app": "a"}),
            (PluginParameterType.TOOLS_SELECTOR, [], []),
            (PluginParameterType.ANY, {"k": "v"}, {"k": "v"}),
        ],
    )
    def test_cast_parameter_value_selectors_valid(self, parameter_type, value, expected):
        """Selector types pass through values already in the expected container shape."""
        assert cast_parameter_value(parameter_type, value) == expected

    @pytest.mark.parametrize(
        ("parameter_type", "value", "message"),
        [
            (PluginParameterType.MODEL_SELECTOR, "bad", "selector must be a dictionary"),
            (PluginParameterType.APP_SELECTOR, "bad", "selector must be a dictionary"),
            (PluginParameterType.TOOLS_SELECTOR, "bad", "tools selector must be a list"),
            (PluginParameterType.ANY, object(), "var selector must be"),
        ],
    )
    def test_cast_parameter_value_selectors_invalid(self, parameter_type, value, message):
        """Wrong container shapes are rejected with a type-specific message."""
        with pytest.raises(ValueError, match=message):
            cast_parameter_value(parameter_type, value)

    @pytest.mark.parametrize(
        ("parameter_type", "value", "expected"),
        [
            (PluginParameterType.ARRAY, [1, 2], [1, 2]),
            (PluginParameterType.ARRAY, "[1, 2]", [1, 2]),
            (PluginParameterType.OBJECT, {"k": "v"}, {"k": "v"}),
            (PluginParameterType.OBJECT, '{"a":1}', {"a": 1}),
        ],
    )
    def test_cast_parameter_value_array_and_object_valid(self, parameter_type, value, expected):
        """ARRAY/OBJECT accept native containers and JSON-encoded strings alike."""
        assert cast_parameter_value(parameter_type, value) == expected

    @pytest.mark.parametrize(
        ("parameter_type", "value", "expected"),
        [
            (PluginParameterType.ARRAY, "bad-json", ["bad-json"]),  # unparsable -> single-item list
            (PluginParameterType.OBJECT, "bad-json", {}),  # unparsable -> empty dict
        ],
    )
    def test_cast_parameter_value_array_and_object_invalid_json_fallback(self, parameter_type, value, expected):
        """Invalid JSON falls back to a container default instead of raising."""
        assert cast_parameter_value(parameter_type, value) == expected

    def test_cast_parameter_value_default_branch_and_wrapped_exception(self):
        """Unknown enum members stringify the value; failures inside casting are
        wrapped in a ValueError that embeds repr(value)."""

        class _Unknown(StrEnum):
            CUSTOM = "custom"

        assert cast_parameter_value(_Unknown.CUSTOM, 12) == "12"

        class _BadString:
            def __str__(self):
                raise RuntimeError("boom")

        # The error message uses repr(), so the default object repr is expected.
        with pytest.raises(
            ValueError,
            match=r"The tool parameter value <.*_BadString object at .* is not in correct type of string\.",
        ):
            cast_parameter_value(PluginParameterType.STRING, _BadString())

    def test_init_frontend_parameter(self):
        """init_frontend_parameter applies defaults, validates values against the
        option list, and rejects missing required values."""
        rule = PluginParameter(
            name="choice",
            label=self._label(),
            required=True,
            default="a",
            options=[PluginParameterOption(value="a", label=self._label())],
        )
        assert init_frontend_parameter(rule, PluginParameterType.SELECT, None) == "a"
        assert init_frontend_parameter(rule, PluginParameterType.NUMBER, 0) == 0
        with pytest.raises(ValueError, match="not in options"):
            init_frontend_parameter(rule, PluginParameterType.SELECT, "b")
        required_rule = PluginParameter(name="required", label=self._label(), required=True, default=None)
        with pytest.raises(ValueError, match="not found in tool config"):
            init_frontend_parameter(required_rule, PluginParameterType.STRING, None)
class TestPluginDaemonEntities:
    """CredentialType helper-method behaviour."""

    def test_credential_type_helpers(self):
        # Display names are derived from the raw enum value.
        assert CredentialType.API_KEY.get_name() == "API KEY"
        assert CredentialType.OAUTH2.get_name() == "AUTH"
        assert CredentialType.UNAUTHORIZED.get_name() == "UNAUTHORIZED"

        class _DuckTypedCredential:
            # Any object exposing a `value` attribute gets the generic treatment.
            value = "custom-type"

        assert CredentialType.get_name(_DuckTypedCredential()) == "CUSTOM TYPE"
        # Editability / validation capability per credential kind.
        assert CredentialType.API_KEY.is_editable() is True
        assert CredentialType.OAUTH2.is_editable() is False
        assert CredentialType.API_KEY.is_validate_allowed() is True
        assert CredentialType.UNAUTHORIZED.is_validate_allowed() is False
        assert set(CredentialType.values()) == {"api-key", "oauth2", "unauthorized"}

    @pytest.mark.parametrize(
        ("raw", "expected"),
        [
            ("api-key", CredentialType.API_KEY),
            ("api_key", CredentialType.API_KEY),
            ("oauth2", CredentialType.OAUTH2),
            ("oauth", CredentialType.OAUTH2),
            ("unauthorized", CredentialType.UNAUTHORIZED),
        ],
    )
    def test_credential_type_of(self, raw, expected):
        """CredentialType.of accepts dash/underscore and alias spellings."""
        assert CredentialType.of(raw) == expected

    def test_credential_type_of_invalid(self):
        """Unknown strings are rejected with a descriptive ValueError."""
        with pytest.raises(ValueError, match="Invalid credential type"):
            CredentialType.of("invalid")
class TestPluginRequestEntities:
    """Validator behaviour on plugin daemon request/response models."""

    def test_request_invoke_llm_converts_prompt_messages(self):
        """Role-tagged dicts are converted into the typed PromptMessage subclasses."""
        payload = RequestInvokeLLM(
            provider="openai",
            model="gpt-4",
            mode="chat",
            prompt_messages=[
                {"role": "user", "content": "u"},
                {"role": "assistant", "content": "a"},
                {"role": "system", "content": "s"},
                {"role": "tool", "content": "t", "tool_call_id": "call-1"},
            ],
        )
        assert isinstance(payload.prompt_messages[0], UserPromptMessage)
        assert isinstance(payload.prompt_messages[1], AssistantPromptMessage)
        assert isinstance(payload.prompt_messages[2], SystemPromptMessage)
        assert isinstance(payload.prompt_messages[3], ToolPromptMessage)

    def test_request_invoke_llm_prompt_messages_must_be_list(self):
        """A non-list prompt_messages payload fails validation."""
        with pytest.raises(ValidationError):
            RequestInvokeLLM(provider="openai", model="gpt-4", mode="chat", prompt_messages="invalid")  # type: ignore[arg-type]

    def test_request_invoke_speech2text_hex_conversion_and_error(self):
        """`file` accepts a hex string (decoded to bytes); raw bytes are rejected."""
        payload = RequestInvokeSpeech2Text(provider="openai", model="m", file=binascii.hexlify(b"abc").decode())
        assert payload.file == b"abc"
        with pytest.raises(ValidationError):
            RequestInvokeSpeech2Text(provider="openai", model="m", file=b"abc")  # type: ignore[arg-type]

    def test_trigger_invoke_event_response_variables_conversion(self):
        """JSON-string variables are parsed; dicts pass through unchanged."""
        converted = TriggerInvokeEventResponse(variables='{"a": 1}', cancelled=False)
        assert converted.variables == {"a": 1}
        passthrough = TriggerInvokeEventResponse(variables={"b": 2}, cancelled=True)
        assert passthrough.variables == {"b": 2}

    def test_trigger_dispatch_response_convert_response(self):
        """Hex-encoded serialized responses round-trip through validation; non-hex
        input fails."""
        response = Response("ok", status=202, headers={"X-Req": "1"})
        encoded = binascii.hexlify(serialize_response(response)).decode()
        parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded)
        assert parsed.response.status_code == 202
        assert parsed.response.get_data() == b"ok"
        with pytest.raises(ValidationError):
            TriggerDispatchResponse(user_id="u", events=["e"], response="not-hex")

    def test_trigger_dispatch_response_payload_default(self):
        """`payload` defaults to an empty dict when not supplied."""
        response = Response("ok", status=200)
        encoded = binascii.hexlify(serialize_response(response)).decode()
        parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded)
        assert parsed.payload == {}

View File

@ -4,7 +4,10 @@ import pytest
from core.agent.entities import AgentInvokeMessage
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
class TestChunkMerger:
@ -458,3 +461,89 @@ class TestChunkMerger:
assert len(result) == 1
assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
assert result[0].message.blob == b"FirstSecondThird"
class TestConverter:
    """convert_parameters_to_plugin_format: File and ToolSelector values are
    serialized to plain dicts; all other values pass through untouched."""

    def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
        """Single File / ToolSelector values are converted to their dict form."""
        file_param = File(
            tenant_id="tenant-1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/file.png",
            storage_key="",
        )
        selector = ToolSelector(
            provider_id="org/plugin/provider",
            credential_id=None,
            tool_name="search",
            tool_description="search tool",
            tool_configuration={"k": "v"},
            tool_parameters={
                "query": ToolSelector.Parameter(
                    name="query",
                    type=ToolParameter.ToolParameterType.STRING,
                    required=True,
                    description="query",
                    default="python",
                    options=[],
                )
            },
        )
        params = {"file": file_param, "selector": selector, "plain": 123}
        converted = convert_parameters_to_plugin_format(params)
        assert converted["file"]["url"] == "https://example.com/file.png"
        assert converted["selector"]["provider_id"] == "org/plugin/provider"
        # Plain values are untouched by the conversion.
        assert converted["plain"] == 123

    def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
        """Homogeneous File/ToolSelector lists are converted element-wise; empty
        lists, mixed-type lists and None values pass through unchanged."""
        file_one = File(
            tenant_id="tenant-1",
            type=FileType.DOCUMENT,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/a.txt",
            storage_key="",
        )
        file_two = File(
            tenant_id="tenant-1",
            type=FileType.DOCUMENT,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/b.txt",
            storage_key="",
        )
        selector_one = ToolSelector(
            provider_id="org/plugin/provider",
            credential_id="cred-1",
            tool_name="t1",
            tool_description="tool 1",
            tool_configuration={},
            tool_parameters={},
        )
        selector_two = ToolSelector(
            provider_id="org/plugin/provider",
            credential_id="cred-2",
            tool_name="t2",
            tool_description="tool 2",
            tool_configuration={},
            tool_parameters={},
        )
        params = {
            "files": [file_one, file_two],
            "selectors": [selector_one, selector_two],
            "empty_list": [],
            "mixed_list": [file_one, "raw"],
            "none_value": None,
        }
        converted = convert_parameters_to_plugin_format(params)
        assert [item["url"] for item in converted["files"]] == [
            "https://example.com/a.txt",
            "https://example.com/b.txt",
        ]
        assert [item["tool_name"] for item in converted["selectors"]] == ["t1", "t2"]
        assert converted["empty_list"] == []
        # A list mixing File and non-File entries is left exactly as provided.
        assert converted["mixed_list"] == [file_one, "raw"]
        assert converted["none_value"] is None

View File

@ -381,6 +381,54 @@ class TestEdgeCases:
assert response.status_code == 200
assert response.get_data() == binary_body
def test_deserialize_request_with_lf_only_newlines(self):
    """Requests using bare LF line endings parse the same as CRLF ones."""
    parsed = deserialize_request(b"POST /lf-only?x=1 HTTP/1.1\nHost: localhost\nX-Test: yes\n\npayload")
    assert parsed.method == "POST"
    assert parsed.path == "/lf-only"
    assert parsed.args.get("x") == "1"
    assert parsed.headers.get("X-Test") == "yes"
    assert parsed.get_data() == b"payload"
def test_deserialize_request_without_header_separator_uses_full_input_as_headers(self):
    """With no blank-line separator, everything after the request line is headers."""
    parsed = deserialize_request(b"GET /no-separator HTTP/1.1\nHost: localhost\nInvalidHeader\n")
    assert parsed.method == "GET"
    assert parsed.path == "/no-separator"
    assert parsed.headers.get("Host") == "localhost"
    # The colon-less line never becomes a header.
    assert parsed.headers.get("InvalidHeader") is None
def test_deserialize_request_empty_payload_raises(self):
    """An empty byte string cannot be parsed into a request."""
    empty = b""
    with pytest.raises(ValueError, match="Empty HTTP request"):
        deserialize_request(empty)
def test_deserialize_response_with_lf_only_newlines(self):
    """Responses with bare LF newlines keep status, headers and body intact."""
    parsed = deserialize_response(b"HTTP/1.1 202 Accepted\nX-Test: yes\n\nbody")
    assert parsed.status_code == 202
    assert parsed.headers.get("X-Test") == "yes"
    assert parsed.get_data() == b"body"
def test_deserialize_response_without_header_separator_uses_full_input_as_headers(self):
    """Without a header/body separator the whole tail is treated as headers."""
    parsed = deserialize_response(b"HTTP/1.1 204 No Content\nX-Test: yes\nInvalidHeader\n")
    assert parsed.status_code == 204
    assert parsed.headers.get("X-Test") == "yes"
    assert parsed.headers.get("InvalidHeader") is None
    assert parsed.get_data() == b""
def test_deserialize_response_empty_payload_raises(self):
    """An empty byte string cannot be parsed into a response."""
    empty = b""
    with pytest.raises(ValueError, match="Empty HTTP response"):
        deserialize_response(empty)
class TestFileUploads:
def test_serialize_request_with_text_file_upload(self):

View File

@ -1,3 +1,4 @@
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from models.model import Conversation
@ -188,3 +191,328 @@ def get_chat_model_args():
context = "I am superman."
return model_config_mock, memory_config, prompt_messages, inputs, context
def test_get_prompt_dispatches_completion_and_chat_and_invalid():
    """get_prompt routes CompletionModelPromptTemplate / ChatModelMessage lists to
    the matching private builder, and returns [] for unrecognized templates."""
    transform = AdvancedPromptTransform()
    model_config = MagicMock(spec=ModelConfigEntity)
    completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic")
    chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")]
    # Stub both builders so only the dispatch logic is exercised.
    transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")])
    transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")])
    completion_result = transform.get_prompt(
        prompt_template=completion_template,
        inputs={"name": "john"},
        query="q",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )
    assert completion_result[0].content == "c"
    chat_result = transform.get_prompt(
        prompt_template=chat_template,
        inputs={"name": "john"},
        query="q",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )
    assert chat_result[0].content == "h"
    # A list whose items are not ChatModelMessage yields an empty result.
    invalid_result = transform.get_prompt(
        prompt_template=cast(list, ["not-chat-model-message"]),
        inputs={"name": "john"},
        query="q",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )
    assert invalid_result == []
def test_completion_prompt_jinja2_with_files():
    """Jinja2 completion templates with attached files produce one user message
    whose content list holds the file content followed by the rendered text."""
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    transform = AdvancedPromptTransform()
    completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
    file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    with (
        # Bypass real Jinja2 rendering and file resolution.
        patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"),
        patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content,
    ):
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_completion_model_prompt_messages(
            prompt_template=completion_template,
            inputs={"name": "John"},
            query="",
            files=[file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
    assert len(messages) == 1
    assert isinstance(messages[0].content, list)
    assert messages[0].content[0].data == "https://example.com/image.jpg"
    assert isinstance(messages[0].content[1], TextPromptMessageContent)
    assert messages[0].content[1].data == "Hi John"
def test_completion_prompt_basic_sets_query_variable():
    """Basic completion templates substitute {{#query#}} with the query text."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    rendered = transform._get_completion_model_prompt_messages(
        prompt_template=CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic"),
        inputs={},
        query="what?",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=config,
    )
    assert rendered[0].content == "Q=what?"
def test_chat_prompt_with_variable_template_and_context():
    """Variable-template mode substitutes custom variables and {{#context#}}."""
    transform = AdvancedPromptTransform(with_variable_tmpl=True)
    config = MagicMock(spec=ModelConfigEntity)
    rendered = transform._get_chat_model_prompt_messages(
        prompt_template=[
            ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)
        ],
        inputs={"#node.name#": "john"},
        query=None,
        files=[],
        context="context-text",
        memory_config=None,
        memory=None,
        model_config=config,
    )
    assert len(rendered) == 1
    assert isinstance(rendered[0], SystemPromptMessage)
    assert rendered[0].content == "sys=john ctx=context-text"
def test_chat_prompt_jinja2_branch_and_invalid_edition():
    """jinja2 messages delegate rendering to Jinja2Formatter; an unrecognized
    edition_type raises ValueError."""
    transform = AdvancedPromptTransform()
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")]
    with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"):
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=prompt_template,
            inputs={"name": "John"},
            query=None,
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
    assert messages[0].content == "Hello John"
    # model_construct bypasses pydantic validation so the bogus edition_type
    # reaches the runtime check inside the transform.
    bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")]
    with pytest.raises(ValueError, match="Invalid edition type"):
        transform._get_chat_model_prompt_messages(
            prompt_template=bad_prompt_template,
            inputs={},
            query=None,
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
def test_chat_prompt_query_template_and_query_only_branch():
    """The memory-config query template becomes the final message, with
    {{#context#}} substituted while {{#sys.query#}} is left verbatim."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    memory_cfg = MemoryConfig(
        window=MemoryConfig.WindowConfig(enabled=False),
        query_prompt_template="query={{#sys.query#}} ctx={{#context#}}",
    )
    rendered = transform._get_chat_model_prompt_messages(
        prompt_template=[ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)],
        inputs={},
        query="what",
        files=[],
        context="ctx",
        memory_config=memory_cfg,
        memory=None,
        model_config=config,
    )
    assert rendered[-1].content == "query={{#sys.query#}} ctx=ctx"
def test_chat_prompt_memory_with_files_and_query():
    """With memory, files and a query, the last message's content is a list whose
    second entry is the query text (the file content presumably precedes it —
    the test only pins index 1)."""
    transform = AdvancedPromptTransform()
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
    memory = MagicMock(spec=TokenBufferMemory)
    prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
    file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    # Stub history appending: pass the prompt list through unchanged.
    transform._append_chat_histories = MagicMock(
        side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages
    )
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=prompt_template,
            inputs={},
            query="q",
            files=[file],
            context=None,
            memory_config=memory_config,
            memory=memory,
            model_config=model_config_mock,
        )
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == "q"
def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
    """Files with no query: if the template ends in a user message its content is
    converted to a [file, text] list; otherwise a new user message is appended
    with an empty text part."""
    transform = AdvancedPromptTransform()
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    # Case 1: last template message is a user message -> it absorbs the file.
    prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)]
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=prompt_with_last_user,
            inputs={},
            query=None,
            files=[file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == "u"
    # Case 2: no trailing user message -> a new user message is appended.
    prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)]
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=prompt_without_last_user,
            inputs={},
            query=None,
            files=[file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
    assert isinstance(messages[-1], UserPromptMessage)
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == ""
def test_chat_prompt_files_with_query_branch():
    """With an empty template, files plus a query yield a trailing message whose
    content list carries the query text at index 1."""
    transform = AdvancedPromptTransform()
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=[],
            inputs={},
            query="query-text",
            files=[file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == "query-text"
def test_set_context_query_histories_variable_helpers():
    """The _set_*_variable helpers fill their template variable with an empty
    string for missing context/query/history inputs, and pass real values through."""
    transform = AdvancedPromptTransform()
    parser_context = PromptTemplateParser(template="{{#context#}}")
    parser_query = PromptTemplateParser(template="{{#query#}}")
    parser_hist = PromptTemplateParser(template="{{#histories#}}")
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=False),
    )
    assert transform._set_context_variable(None, parser_context, {})["#context#"] == ""
    assert transform._set_query_variable("", parser_query, {})["#query#"] == ""
    assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x"
    # With no memory object, histories resolve to an empty string.
    assert (
        transform._set_histories_variable(
            memory=None,  # type: ignore[arg-type]
            memory_config=memory_config,
            raw_prompt="{{#histories#}}",
            role_prefix=memory_config.role_prefix,  # type: ignore[arg-type]
            parser=parser_hist,
            prompt_inputs={},
            model_config=model_config_mock,
        )["#histories#"]
        == ""
    )

View File

@ -2,12 +2,14 @@ from uuid import uuid4
from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
class MockMessage:
def __init__(self, id, parent_message_id):
def __init__(self, id, parent_message_id, answer="answer"):
self.id = id
self.parent_message_id = parent_message_id
self.answer = answer
def __getitem__(self, item):
return getattr(self, item)
@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages():
result = extract_thread_messages(messages)
assert len(result) == 4
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
def test_extract_thread_messages_breaks_when_parent_is_none():
    """Traversal stops at the first message whose parent_message_id is None."""
    newest, older = str(uuid4()), str(uuid4())
    thread = extract_thread_messages([MockMessage(newest, None), MockMessage(older, UUID_NIL)])
    assert [message.id for message in thread] == [newest]
def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker):
    """The newest message is not counted while its answer is still empty."""
    parent_id, child_id = str(uuid4()), str(uuid4())
    history = [
        MockMessage(child_id, parent_id, answer=""),  # freshly created, not yet answered
        MockMessage(parent_id, UUID_NIL, answer="ok"),
    ]
    scalars_patch = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
    scalars_patch.return_value.all.return_value = history
    assert get_thread_messages_length("conversation-1") == 1
    scalars_patch.assert_called_once()
def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker):
    """A latest message with a non-empty answer is counted like any other."""
    parent_id, child_id = str(uuid4()), str(uuid4())
    history = [
        MockMessage(child_id, parent_id, answer="latest-answer"),
        MockMessage(parent_id, UUID_NIL, answer="older-answer"),
    ]
    scalars_patch = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
    scalars_patch.return_value.all.return_value = history
    assert get_thread_messages_length("conversation-2") == 2

View File

@ -1,6 +1,11 @@
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
@ -25,3 +30,82 @@ def test_dump_prompt_message():
)
data = prompt.model_dump()
assert data["content"][0].get("url") == example_url
def test_prompt_messages_to_prompt_for_saving_chat_mode():
    """Chat-mode saving flattens multi-part content into text + files, preserves
    assistant tool calls, and drops messages with an unrecognized role (four
    inputs produce three saved prompts)."""
    chat_messages = [
        UserPromptMessage(
            content=[
                TextPromptMessageContent(data="hello "),
                ImagePromptMessageContent(
                    url="https://example.com/image1.jpg",
                    format="jpg",
                    mime_type="image/jpeg",
                    detail=ImagePromptMessageContent.DETAIL.HIGH,
                ),
                AudioPromptMessageContent(
                    url="https://example.com/audio1.mp3",
                    format="mp3",
                    mime_type="audio/mpeg",
                ),
                TextPromptMessageContent(data="world"),
            ]
        ),
        AssistantPromptMessage(
            content="assistant-text",
            tool_calls=[
                {
                    "id": "tool-1",
                    "type": "function",
                    "function": {"name": "search", "arguments": '{"q":"python"}'},
                }
            ],
        ),
        ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"),
        # model_construct bypasses validation so an unknown role reaches the util.
        UserPromptMessage.model_construct(role="unknown", content="skip"),  # type: ignore[arg-type]
    ]
    prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages)
    assert len(prompts) == 3
    assert prompts[0]["role"] == "user"
    # The two text parts are concatenated; media parts move into "files".
    assert prompts[0]["text"] == "hello world"
    assert prompts[0]["files"][0]["type"] == "image"
    assert prompts[0]["files"][1]["type"] == "audio"
    assert prompts[1]["role"] == "assistant"
    assert prompts[1]["text"] == "assistant-text"
    assert prompts[1]["tool_calls"][0]["function"]["name"] == "search"
    assert prompts[2]["role"] == "tool"
def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files():
    """Completion-mode saving joins text parts and emits a "files" key only when
    media content is present."""
    completion_message_with_files = UserPromptMessage(
        content=[
            TextPromptMessageContent(data="first "),
            TextPromptMessageContent(data="second"),
            ImagePromptMessageContent(
                url="https://example.com/image2.jpg",
                format="jpg",
                mime_type="image/jpeg",
                detail=ImagePromptMessageContent.DETAIL.LOW,
            ),
        ]
    )
    prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
        ModelMode.COMPLETION, [completion_message_with_files]
    )
    # "files" content is checked separately below, so the dict comparison reuses it.
    assert prompts == [
        {
            "role": "user",
            "text": "first second",
            "files": prompts[0]["files"],
        }
    ]
    assert prompts[0]["files"][0]["type"] == "image"
    completion_message_text_only = UserPromptMessage(content="plain text")
    prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
        ModelMode.COMPLETION, [completion_message_text_only]
    )
    assert prompts == [{"role": "user", "text": "plain text"}]

View File

@ -1,4 +1,10 @@
# from unittest.mock import MagicMock
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.prompt.prompt_transform import PromptTransform
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
# from core.app.app_config.entities import ModelConfigEntity
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -9,44 +15,217 @@
# from core.prompt.prompt_transform import PromptTransform
# def test__calculate_rest_token():
# model_schema_mock = MagicMock(spec=AIModelEntity)
# parameter_rule_mock = MagicMock(spec=ParameterRule)
# parameter_rule_mock.name = "max_tokens"
# model_schema_mock.parameter_rules = [parameter_rule_mock]
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
class TestPromptTransform:
    """Unit tests for PromptTransform: model-runtime resolution, rest-token
    budgeting, and chat-history retrieval helpers.

    All collaborators are stubbed with SimpleNamespace/MagicMock, so no real
    model providers are contacted.
    """

    def test_resolve_model_runtime_requires_model_config_or_instance(self):
        """Calling _resolve_model_runtime with neither argument must raise ValueError."""
        transform = PromptTransform()
        with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."):
            transform._resolve_model_runtime()

    def test_resolve_model_runtime_builds_model_instance_from_model_config(self):
        """When only model_config is given, a ModelInstance is built from its bundle/model
        and the config's credentials/parameters/stop are copied onto the instance."""
        transform = PromptTransform()
        fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = fake_model_schema
        # Instance starts with empty credentials/parameters/stop so the test can
        # observe _resolve_model_runtime filling them in from model_config.
        fake_model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials=None,
            parameters=None,
            stop=None,
        )
        model_config = SimpleNamespace(
            provider_model_bundle=object(),
            model="config-model",
            credentials={"api_key": "secret"},
            parameters={"temperature": 0.1},
            stop=["END"],
            model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
        )
        with patch(
            "core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance
        ) as model_instance_cls:
            model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config)
        model_instance_cls.assert_called_once_with(
            provider_model_bundle=model_config.provider_model_bundle,
            model=model_config.model,
        )
        # The schema lookup uses the *instance* model name with the config credentials.
        fake_model_type_instance.get_model_schema.assert_called_once_with(
            model="resolved-model",
            credentials={"api_key": "secret"},
        )
        assert model_instance is fake_model_instance
        assert model_instance.credentials == {"api_key": "secret"}
        assert model_instance.parameters == {"temperature": 0.1}
        assert model_instance.stop == ["END"]
        assert model_schema is fake_model_schema

    def test_resolve_model_runtime_uses_model_config_schema_fallback(self):
        """If the instance's schema lookup returns None, model_config.model_schema is used."""
        transform = PromptTransform()
        fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials={"api_key": "secret"},
            parameters={},
        )
        model_config = SimpleNamespace(model_schema=fallback_schema)
        resolved_model_instance, resolved_schema = transform._resolve_model_runtime(
            model_config=model_config,
            model_instance=model_instance,
        )
        assert resolved_model_instance is model_instance
        assert resolved_schema is fallback_schema

    def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self):
        """With no model_config fallback, a missing schema must raise ValueError."""
        transform = PromptTransform()
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials={"api_key": "secret"},
            parameters={},
        )
        with pytest.raises(ValueError, match="Model schema not found for the provided model instance."):
            transform._resolve_model_runtime(model_instance=model_instance)

    def test_calculate_rest_token_defaults_when_context_size_missing(self):
        """Without CONTEXT_SIZE in the schema, the rest-token budget defaults to 2000."""
        transform = PromptTransform()
        fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0)
        fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
            provider_model_bundle=object(),
            model="test-model",
            parameters={},
        )
        rest = transform._calculate_rest_token([], model_config=model_config)
        assert rest == 2000

    def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self):
        """context_size - max_tokens - prompt_tokens can go negative; result clamps at 0."""
        transform = PromptTransform()
        parameter_rule = SimpleNamespace(name="max_tokens", use_template=None)
        # 100 (context) - 50 (max_tokens) - 95 (prompt tokens) < 0 -> clamped to 0.
        fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95)
        fake_model_schema = SimpleNamespace(
            model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
            parameter_rules=[parameter_rule],
        )
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(
                model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
                parameter_rules=[parameter_rule],
            ),
            provider_model_bundle=object(),
            model="test-model",
            parameters={"max_tokens": 50},
        )
        rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
        assert rest == 0

    def test_calculate_rest_token_supports_use_template_parameter(self):
        """A rule whose use_template is 'max_tokens' is honored when budgeting tokens."""
        transform = PromptTransform()
        parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens")
        fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20)
        fake_model_schema = SimpleNamespace(
            model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
            parameter_rules=[parameter_rule],
        )
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(
                model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
                parameter_rules=[parameter_rule],
            ),
            provider_model_bundle=object(),
            model="test-model",
            parameters={"max_tokens": 30},
        )
        # 200 - 30 - 20 = 150 tokens left for history/context.
        rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
        assert rest == 150

    def test_get_history_messages_from_memory_with_and_without_window(self):
        """The memory window size is forwarded as message_limit only when enabled."""
        transform = PromptTransform()
        memory = MagicMock()
        memory.get_history_prompt_text.return_value = "history"
        memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3))
        result = transform._get_history_messages_from_memory(
            memory=memory,
            memory_config=memory_config_with_window,
            max_token_limit=100,
            human_prefix="Human",
            ai_prefix="Assistant",
        )
        assert result == "history"
        memory.get_history_prompt_text.assert_called_with(
            max_token_limit=100,
            human_prefix="Human",
            ai_prefix="Assistant",
            message_limit=3,
        )
        memory.reset_mock()
        # Disabled window: no message_limit is forwarded (and no prefixes were given).
        memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2))
        transform._get_history_messages_from_memory(
            memory=memory,
            memory_config=memory_config_no_window,
            max_token_limit=50,
        )
        memory.get_history_prompt_text.assert_called_with(max_token_limit=50)

    def test_get_history_messages_list_from_memory_with_and_without_window(self):
        """A positive window size becomes message_limit; size 0 falls back to None."""
        transform = PromptTransform()
        memory = MagicMock()
        memory.get_history_prompt_messages.return_value = ["m1", "m2"]
        memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2))
        result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120)
        assert result == ["m1", "m2"]
        memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2)
        memory.reset_mock()
        memory.get_history_prompt_messages.return_value = ["only"]
        # size=0 is treated as "no limit" even though the window is enabled.
        memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0))
        result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10)
        assert result == ["only"]
        memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None)

    def test_append_chat_histories_extends_prompt_messages(self, monkeypatch):
        """History messages from memory are appended after the existing prompts."""
        transform = PromptTransform()
        memory = MagicMock()
        memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None))
        monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99)
        monkeypatch.setattr(
            transform,
            "_get_history_messages_list_from_memory",
            lambda memory, memory_config, max_token_limit: ["h1", "h2"],
        )
        result = transform._append_chat_histories(
            memory=memory,
            memory_config=memory_config,
            prompt_messages=["p1"],
            model_config=SimpleNamespace(),
        )
        assert result == ["p1", "h1", "h2"]

View File

@ -1,9 +1,29 @@
from unittest.mock import MagicMock
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from core.prompt.simple_prompt_transform import SimplePromptTransform
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from models.model import AppMode, Conversation
@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages():
assert len(prompt_messages) == 1
assert stops == prompt_rules.get("stops")
assert prompt_messages[0].content == real_prompt
def test_get_prompt_dispatches_chat_and_completion():
    """get_prompt must route to the chat or completion helper based on the model mode."""
    transform = SimplePromptTransform()
    prompt_entity = SimpleNamespace(simple_prompt_template="hello")

    chat_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    chat_config.mode = "chat"
    completion_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    completion_config.mode = "completion"

    transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None))
    transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"]))

    shared_kwargs = dict(
        app_mode=AppMode.CHAT,
        prompt_template_entity=prompt_entity,
        inputs={"n": 1},
        query="q",
        files=[],
        context=None,
        memory=None,
    )

    # Chat mode routes to the chat helper.
    messages, stops = transform.get_prompt(model_config=chat_config, **shared_kwargs)
    assert messages == ["chat-msg"]
    assert stops is None

    # Completion mode routes to the completion helper.
    messages, stops = transform.get_prompt(model_config=completion_config, **shared_kwargs)
    assert messages == ["completion-msg"]
    assert stops == ["stop"]
def test_get_prompt_str_and_rules_type_validation_errors():
    """_get_prompt_str_and_rules must reject malformed get_prompt_template payloads."""
    transform = SimplePromptTransform()
    model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config.provider = "openai"
    model_config.model = "gpt-4"
    valid_prompt_template = SimplePromptTransform().get_prompt_template(
        AppMode.CHAT, "openai", "gpt-4", "", False, False
    )["prompt_template"]

    def expect_type_error(template_payload, match):
        # Stub the template lookup with a malformed payload and expect a TypeError.
        transform.get_prompt_template = MagicMock(return_value=template_payload)
        with pytest.raises(TypeError, match=match):
            transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)

    base_payload = {
        "prompt_template": valid_prompt_template,
        "custom_variable_keys": "not-list",
        "special_variable_keys": [],
        "prompt_rules": {},
    }
    expect_type_error(base_payload, "custom_variable_keys")
    expect_type_error(
        {**base_payload, "custom_variable_keys": [], "special_variable_keys": "not-list"},
        "special_variable_keys",
    )
    expect_type_error(
        {**base_payload, "custom_variable_keys": [], "special_variable_keys": [], "prompt_template": 123},
        "PromptTemplateParser",
    )
    expect_type_error(
        {
            **base_payload,
            "custom_variable_keys": [],
            "special_variable_keys": [],
            "prompt_template": valid_prompt_template,
            "prompt_rules": "not-dict",
        },
        "prompt_rules",
    )
def test_chat_model_prompt_messages_uses_prompt_when_query_empty():
    """An empty query falls back to the rendered prompt as the last user message."""
    transform = SimplePromptTransform()
    model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {}))
    transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text"))

    messages, _ = transform._get_chat_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt="",
        inputs={},
        query="",
        files=[],
        context=None,
        memory=None,
        model_config=model_config,
    )

    # The rendered prompt, not the (empty) query, becomes the user content.
    assert messages[0].content == "prompt-text"
    transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None)
def test_completion_model_prompt_messages_empty_stops_becomes_none():
    """An empty stops list coming from the prompt rules is normalized to None."""
    transform = SimplePromptTransform()
    model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []}))

    messages, stops = transform._get_completion_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt="",
        inputs={},
        query="q",
        files=[],
        context=None,
        memory=None,
        model_config=model_config,
    )

    assert stops is None
    assert len(messages) == 1
def test_get_last_user_message_with_files_and_context_files():
    """Files and context files are converted to content parts placed ahead of the text."""
    transform = SimplePromptTransform()
    user_file = SimpleNamespace()
    context_file = SimpleNamespace()
    with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        # One converted content part per file, consumed in order.
        to_content.side_effect = [
            ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"),
            ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"),
        ]
        message = transform._get_last_user_message(
            prompt="hello",
            files=[user_file],
            context_files=[context_file],
            image_detail_config=None,
        )
    content = message.content
    assert isinstance(content, list)
    assert content[0].data == "https://example.com/a.jpg"
    assert content[1].data == "https://example.com/b.jpg"
    assert isinstance(content[2], TextPromptMessageContent)
    assert content[2].data == "hello"
def test_prompt_file_name_branches():
    """_prompt_file_name picks the template file by app mode and baichuan model detection."""
    transform = SimplePromptTransform()
    expectations = [
        (AppMode.CHAT, "openai", "gpt-4", "common_chat"),
        (AppMode.COMPLETION, "openai", "gpt-4", "common_completion"),
        (AppMode.COMPLETION, "baichuan", "Baichuan2", "baichuan_completion"),
        (AppMode.CHAT, "huggingface_hub", "baichuan-13b", "baichuan_chat"),
    ]
    for app_mode, provider, model, expected in expectations:
        assert transform._prompt_file_name(app_mode, provider, model) == expected
def test_advanced_prompt_templates_constants_are_importable():
    """Sanity-check the advanced prompt template constants and their top-level keys."""
    assert isinstance(CONTEXT, str)
    assert isinstance(BAICHUAN_CONTEXT, str)
    completion_configs = (
        CHAT_APP_COMPLETION_PROMPT_CONFIG,
        COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
        BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
        BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    )
    chat_configs = (
        CHAT_APP_CHAT_PROMPT_CONFIG,
        COMPLETION_APP_CHAT_PROMPT_CONFIG,
        BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
        BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
    )
    for config in completion_configs:
        assert "completion_prompt_config" in config
    for config in chat_configs:
        assert "chat_prompt_config" in config

View File

@ -1,11 +1,14 @@
import dataclasses
import orjson
import pytest
from pydantic import BaseModel
from core.helper import encrypter
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.runtime import VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables.segment_group import SegmentGroup
from dify_graph.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
@ -23,6 +26,11 @@ from dify_graph.variables.segments import (
get_segment_discriminator,
)
from dify_graph.variables.types import SegmentType
from dify_graph.variables.utils import (
dumps_with_segments,
segment_orjson_default,
to_selector,
)
from dify_graph.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
@ -379,3 +387,125 @@ class TestSegmentDumpAndLoad:
assert get_segment_discriminator("not_a_dict") is None
assert get_segment_discriminator(42) is None
assert get_segment_discriminator(object) is None
class TestSegmentAdditionalProperties:
    """Covers the derived text/log/markdown/size properties of segment types."""

    def test_base_segment_text_log_markdown_size_and_to_object(self):
        """StringSegment exposes its raw value through every derived property."""
        seg = StringSegment(value="hello")
        assert seg.text == "hello"
        assert seg.log == "hello"
        assert seg.markdown == "hello"
        assert seg.size > 0
        assert seg.to_object() == "hello"

    def test_none_segment_empty_outputs(self):
        """NoneSegment renders as the empty string everywhere."""
        seg = NoneSegment()
        assert seg.text == ""
        assert seg.log == ""
        assert seg.markdown == ""

    def test_object_segment_json_outputs(self):
        """ObjectSegment serializes compact JSON for text and indented JSON for log/markdown."""
        seg = ObjectSegment(value={"key": "", "n": 1})
        assert seg.text == '{"key": "", "n": 1}'
        assert seg.log == '{\n "key": "",\n "n": 1\n}'
        assert seg.markdown == '{\n "key": "",\n "n": 1\n}'

    def test_array_segment_text_and_markdown(self):
        """ArrayAnySegment: empty arrays render empty text; items become markdown bullets."""
        empty = ArrayAnySegment(value=[])
        filled = ArrayAnySegment(value=[1, "two"])
        assert empty.text == ""
        assert filled.text == "[1, 'two']"
        assert filled.markdown == "- 1\n- two"

    def test_file_segment_properties(self):
        """FileSegment renders a markdown link but contributes no text or log."""
        remote_file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="doc.txt")
        seg = FileSegment(value=remote_file)
        assert seg.markdown == "[doc.txt](https://example.com/file.txt)"
        assert seg.log == ""
        assert seg.text == ""

    def test_array_string_segment_text_branches(self):
        """ArrayStringSegment text is empty for [] and JSON-like for non-empty values."""
        empty = ArrayStringSegment(value=[])
        filled = ArrayStringSegment(value=["hello", "世界"])
        assert empty.text == ""
        assert filled.text == '["hello", "世界"]'

    def test_array_file_segment_markdown_and_empty_text_log(self):
        """ArrayFileSegment joins per-file markdown links; text and log stay empty."""
        first = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt")
        second = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt")
        seg = ArrayFileSegment(value=[first, second])
        assert seg.markdown == "[a.txt](https://example.com/file.txt)\n[b.txt](https://example.com/file.txt)"
        assert seg.log == ""
        assert seg.text == ""
class TestSegmentGroupAdditional:
    """Extra coverage for SegmentGroup rendering and conversion."""

    def test_segment_group_markdown_and_to_object(self):
        """markdown concatenates member markdown; to_object maps members to Python values."""
        group = SegmentGroup(value=[StringSegment(value="A"), NoneSegment(), StringSegment(value="B")])
        assert group.markdown == "AB"
        assert group.to_object() == ["A", None, "B"]
class TestSegmentUtils:
    """Covers selector construction and orjson serialization helpers for segments."""

    def test_to_selector_without_paths(self):
        """Node id plus name yields a two-part selector."""
        assert to_selector("node-1", "output") == ["node-1", "output"]

    def test_to_selector_with_paths(self):
        """Extra path components are appended after the node id and name."""
        assert to_selector("node-1", "output", ("a", "b")) == ["node-1", "output", "a", "b"]

    def test_array_file_segment_serialization(self):
        """ArrayFileSegment serializes to one dict per contained file."""
        first = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt")
        second = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt")
        serialized = segment_orjson_default(ArrayFileSegment(value=[first, second]))
        assert len(serialized) == 2
        assert serialized[0]["filename"] == "a.txt"
        assert serialized[1]["filename"] == "b.txt"

    def test_file_segment_serialization(self):
        """FileSegment serializes to a dict carrying filename and remote_url."""
        single = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="single.txt")
        serialized = segment_orjson_default(FileSegment(value=single))
        assert serialized["filename"] == "single.txt"
        assert serialized["remote_url"] == "https://example.com/file.txt"

    def test_segment_group_and_segment_serialization(self):
        """Groups serialize to lists of member values; plain segments to their value."""
        group = SegmentGroup(value=[StringSegment(value="a"), StringSegment(value="b")])
        assert segment_orjson_default(group) == ["a", "b"]
        assert segment_orjson_default(StringSegment(value="value")) == "value"

    def test_segment_orjson_default_unsupported_type(self):
        """Unknown types are rejected with orjson's standard TypeError message."""
        with pytest.raises(TypeError, match="not JSON serializable"):
            segment_orjson_default(object())

    def test_dumps_with_segments(self):
        """dumps_with_segments unwraps segments/groups and stringifies non-str keys."""
        payload = {
            "segment": StringSegment(value="hello"),
            "group": SegmentGroup(value=[StringSegment(value="x"), StringSegment(value="y")]),
            1: "numeric-key",
        }
        decoded = orjson.loads(dumps_with_segments(payload))
        assert decoded["segment"] == "hello"
        assert decoded["group"] == ["x", "y"]
        assert decoded["1"] == "numeric-key"

View File

@ -1,5 +1,7 @@
import pytest
from dify_graph.variables.segment_group import SegmentGroup
from dify_graph.variables.segments import StringSegment
from dify_graph.variables.types import ArrayValidation, SegmentType
@ -69,22 +71,36 @@ class TestSegmentTypeIsValidArrayValidation:
"""
    def test_array_validation_all_success(self):
        """ALL validation passes when every element matches the declared element type."""
        # Arrange
        value = ["hello", "world", "foo"]
        assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
        # Act
        is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
        # Assert
        assert is_valid
    def test_array_validation_all_fail(self):
        """ALL validation fails when any element does not match the declared element type."""
        # Arrange
        value = ["hello", 123, "world"]
        # Should return False, since 123 is not a string
        assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
        # Act
        is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
        # Assert
        assert not is_valid
    def test_array_validation_first(self):
        """FIRST validation inspects only the first element, ignoring later mismatches."""
        # Arrange
        value = ["hello", 123, None]
        assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
        # Act
        is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
        # Assert
        assert is_valid
    def test_array_validation_none(self):
        """NONE validation skips element checks entirely, so any list is accepted."""
        # Arrange
        value = [1, 2, 3]
        # validation is None, skip
        assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
        # Act
        is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
        # Assert
        assert is_valid
class TestSegmentTypeGetZeroValue:
@ -163,3 +179,62 @@ class TestSegmentTypeGetZeroValue:
for seg_type in unsupported_types:
with pytest.raises(ValueError, match="unsupported variable type"):
SegmentType.get_zero_value(seg_type)
class TestSegmentTypeInferSegmentType:
    """Table-driven checks for SegmentType.infer_segment_type over supported values."""

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            # array inputs: homogeneous lists map to typed arrays, mixed/nested to ARRAY_ANY
            ([], SegmentType.ARRAY_NUMBER),
            ([1, 2, 3], SegmentType.ARRAY_NUMBER),
            ([1, 2.5], SegmentType.ARRAY_NUMBER),
            (["a", "b"], SegmentType.ARRAY_STRING),
            ([{"k": "v"}], SegmentType.ARRAY_OBJECT),
            ([None], SegmentType.ARRAY_ANY),
            ([True, False], SegmentType.ARRAY_BOOLEAN),
            ([[1], [2]], SegmentType.ARRAY_ANY),
            ([1, "a"], SegmentType.ARRAY_ANY),
            # scalar inputs
            (None, SegmentType.NONE),
            (True, SegmentType.BOOLEAN),
            (1, SegmentType.INTEGER),
            (1.2, SegmentType.FLOAT),
            ("abc", SegmentType.STRING),
            ({"k": "v"}, SegmentType.OBJECT),
        ],
    )
    def test_infer_segment_type_supported_values(self, value, expected):
        """Each sample value must infer to its documented segment type."""
        assert SegmentType.infer_segment_type(value) == expected
class TestSegmentTypeAdditionalMethods:
    """Covers cast_value, exposed/element types, group validation, and the unreachable guard."""

    def test_cast_value_for_bool_number_and_array_number(self):
        """Booleans cast to ints for numeric targets; mixed arrays pass through unchanged."""
        assert SegmentType.cast_value(True, SegmentType.INTEGER) == 1
        assert SegmentType.cast_value(False, SegmentType.NUMBER) == 0
        assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0]
        # A list that is not purely booleans is returned as the same object.
        mixed_values = [True, 1]
        assert SegmentType.cast_value(mixed_values, SegmentType.ARRAY_NUMBER) is mixed_values
        assert SegmentType.cast_value("x", SegmentType.STRING) == "x"

    def test_exposed_type_and_element_type(self):
        """Numeric kinds expose NUMBER; element_type is defined only for array kinds."""
        assert SegmentType.INTEGER.exposed_type() == SegmentType.NUMBER
        assert SegmentType.FLOAT.exposed_type() == SegmentType.NUMBER
        assert SegmentType.STRING.exposed_type() == SegmentType.STRING
        assert SegmentType.ARRAY_STRING.element_type() == SegmentType.STRING
        assert SegmentType.ARRAY_ANY.element_type() is None
        with pytest.raises(ValueError, match="element_type is only supported by array type"):
            SegmentType.STRING.element_type()

    def test_group_validation_for_segment_group_and_list(self):
        """GROUP accepts a SegmentGroup or a list of segments; plain values fail."""
        assert SegmentType.GROUP.is_valid(SegmentGroup(value=[StringSegment(value="a")])) is True
        assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True
        assert SegmentType.GROUP.is_valid(["not-segment"]) is False

    def test_unreachable_assertion_branch(self, monkeypatch):
        """Forcing is_array_type to report False makes the defensive assert fire."""
        monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False)
        with pytest.raises(AssertionError, match="unreachable"):
            SegmentType.ARRAY_STRING.is_valid(["a"])