mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
test: add UTs for api core.trigger (#32587)
This commit is contained in:
parent
3f3b788356
commit
4f835107b2
0
api/tests/unit_tests/core/trigger/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/__init__.py
Normal file
93
api/tests/unit_tests/core/trigger/conftest.py
Normal file
93
api/tests/unit_tests/core/trigger/conftest.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Shared factory helpers for core.trigger test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.entities.entities import (
|
||||
EventEntity,
|
||||
EventIdentity,
|
||||
EventParameter,
|
||||
OAuthSchema,
|
||||
Subscription,
|
||||
SubscriptionConstructor,
|
||||
TriggerProviderEntity,
|
||||
TriggerProviderIdentity,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
# Valid format for TriggerProviderID: org/plugin/provider
|
||||
VALID_PROVIDER_ID = "testorg/testplugin/testprovider"
|
||||
|
||||
|
||||
def i18n(text: str = "test") -> I18nObject:
|
||||
return I18nObject(en_US=text, zh_Hans=text)
|
||||
|
||||
|
||||
def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity:
|
||||
return EventEntity(
|
||||
identity=EventIdentity(author="a", name=name, label=i18n(name)),
|
||||
description=i18n(name),
|
||||
parameters=parameters or [],
|
||||
)
|
||||
|
||||
|
||||
def make_provider_entity(
|
||||
name: str = "test_provider",
|
||||
events: list[EventEntity] | None = None,
|
||||
constructor: SubscriptionConstructor | None = None,
|
||||
subscription_schema: list[ProviderConfig] | None = None,
|
||||
icon: str | None = "icon.png",
|
||||
icon_dark: str | None = None,
|
||||
) -> TriggerProviderEntity:
|
||||
return TriggerProviderEntity(
|
||||
identity=TriggerProviderIdentity(
|
||||
author="a",
|
||||
name=name,
|
||||
label=i18n(name),
|
||||
description=i18n(name),
|
||||
icon=icon,
|
||||
icon_dark=icon_dark,
|
||||
),
|
||||
events=events if events is not None else [make_event()],
|
||||
subscription_constructor=constructor,
|
||||
subscription_schema=subscription_schema or [],
|
||||
)
|
||||
|
||||
|
||||
def make_controller(
|
||||
entity: TriggerProviderEntity | None = None,
|
||||
tenant_id: str = "tenant-1",
|
||||
provider_id: str = VALID_PROVIDER_ID,
|
||||
) -> PluginTriggerProviderController:
|
||||
return PluginTriggerProviderController(
|
||||
entity=entity or make_provider_entity(),
|
||||
plugin_id="plugin-1",
|
||||
plugin_unique_identifier="uid-1",
|
||||
provider_id=TriggerProviderID(provider_id),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def make_subscription(**overrides: Any) -> Subscription:
|
||||
defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}}
|
||||
defaults.update(overrides)
|
||||
return Subscription(**defaults)
|
||||
|
||||
|
||||
def make_provider_config(
|
||||
name: str = "api_key", required: bool = True, config_type: str = "secret-input"
|
||||
) -> ProviderConfig:
|
||||
return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required)
|
||||
|
||||
|
||||
def make_constructor(
|
||||
credentials_schema: list[ProviderConfig] | None = None,
|
||||
oauth_schema: OAuthSchema | None = None,
|
||||
) -> SubscriptionConstructor:
|
||||
return SubscriptionConstructor(
|
||||
parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema
|
||||
)
|
||||
0
api/tests/unit_tests/core/trigger/debug/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/debug/__init__.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
Tests for core.trigger.debug.event_bus.TriggerDebugEventBus.
|
||||
|
||||
Covers: Lua-script dispatch/poll with Redis error resilience.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from redis import RedisError
|
||||
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_dispatch_count(self, mock_redis):
|
||||
mock_redis.eval.return_value = 3
|
||||
event = MagicMock()
|
||||
event.model_dump_json.return_value = '{"test": true}'
|
||||
|
||||
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
|
||||
|
||||
assert result == 3
|
||||
mock_redis.eval.assert_called_once()
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_redis_error_returns_zero(self, mock_redis):
|
||||
mock_redis.eval.side_effect = RedisError("connection lost")
|
||||
event = MagicMock()
|
||||
event.model_dump_json.return_value = "{}"
|
||||
|
||||
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestPoll:
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_deserialized_event(self, mock_redis):
|
||||
event_json = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
).model_dump_json()
|
||||
mock_redis.eval.return_value = event_json
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "push"
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_none_when_no_event(self, mock_redis):
|
||||
mock_redis.eval.return_value = None
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_redis_error_returns_none(self, mock_redis):
|
||||
mock_redis.eval.side_effect = RedisError("timeout")
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@ -0,0 +1,276 @@
|
||||
"""
|
||||
Tests for core.trigger.debug.event_selectors.
|
||||
|
||||
Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory,
|
||||
and select_trigger_debug_events orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.trigger.debug.event_selectors import (
|
||||
PluginTriggerDebugEventPoller,
|
||||
ScheduleTriggerDebugEventPoller,
|
||||
WebhookTriggerDebugEventPoller,
|
||||
create_event_poller,
|
||||
select_trigger_debug_events,
|
||||
)
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
|
||||
from core.workflow.enums import NodeType
|
||||
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
|
||||
|
||||
|
||||
def _make_poller_args(node_config: dict | None = None) -> dict:
|
||||
return {
|
||||
"tenant_id": "t1",
|
||||
"user_id": "u1",
|
||||
"app_id": "a1",
|
||||
"node_config": node_config or {"data": {}},
|
||||
"node_id": "n1",
|
||||
}
|
||||
|
||||
|
||||
def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict:
|
||||
"""Valid node config for TriggerEventNodeData.model_validate."""
|
||||
return {
|
||||
"data": {
|
||||
"title": "test",
|
||||
"plugin_id": "org/testplugin",
|
||||
"provider_id": provider_id,
|
||||
"event_name": "push",
|
||||
"subscription_id": "s1",
|
||||
"plugin_unique_identifier": "uid-1",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestPluginTriggerDebugEventPoller:
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_workflow_args_on_success(self, mock_bus):
|
||||
event = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
|
||||
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
|
||||
variables={"repo": "dify"},
|
||||
cancelled=False,
|
||||
)
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"repo": "dify"}
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_no_event(self, mock_bus):
|
||||
mock_bus.poll.return_value = None
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_invoke_cancelled(self, mock_bus):
|
||||
event = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
|
||||
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
|
||||
variables={},
|
||||
cancelled=True,
|
||||
)
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
|
||||
class TestWebhookTriggerDebugEventPoller:
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_uses_inputs_directly_when_present(self, mock_bus):
|
||||
event = WebhookDebugEvent(
|
||||
timestamp=100,
|
||||
request_id="r1",
|
||||
node_id="n1",
|
||||
payload={"inputs": {"key": "val"}, "webhook_data": {}},
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"key": "val"}
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_falls_back_to_webhook_data(self, mock_bus):
|
||||
event = WebhookDebugEvent(
|
||||
timestamp=100,
|
||||
request_id="r1",
|
||||
node_id="n1",
|
||||
payload={"webhook_data": {"body": "raw"}},
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc:
|
||||
mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"}
|
||||
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"parsed": "data"}
|
||||
mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"})
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_no_event(self, mock_bus):
|
||||
mock_bus.poll.return_value = None
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
|
||||
class TestScheduleTriggerDebugEventPoller:
|
||||
def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime):
|
||||
"""Set up mocks and create a schedule poller."""
|
||||
mock_redis.get.return_value = None
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
return ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.redis_client")
|
||||
@patch("core.trigger.debug.event_selectors.naive_utc_now")
|
||||
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
|
||||
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
|
||||
def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
next_run = datetime(2025, 1, 1, 13, 0, 0) # future
|
||||
mock_now.return_value = now
|
||||
mock_calc.return_value = next_run
|
||||
mock_ensure.return_value = next_run
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
|
||||
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.redis_client")
|
||||
@patch("core.trigger.debug.event_selectors.naive_utc_now")
|
||||
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
|
||||
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
|
||||
def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
|
||||
now = datetime(2025, 1, 1, 14, 0, 0)
|
||||
next_run = datetime(2025, 1, 1, 12, 0, 0) # past
|
||||
mock_now.return_value = now
|
||||
mock_calc.return_value = next_run
|
||||
mock_ensure.return_value = next_run
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
|
||||
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateEventPoller:
|
||||
def _workflow_with_node(self, node_type: NodeType):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = node_type
|
||||
return wf
|
||||
|
||||
def test_creates_plugin_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, PluginTriggerDebugEventPoller)
|
||||
|
||||
def test_creates_webhook_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, WebhookTriggerDebugEventPoller)
|
||||
|
||||
def test_creates_schedule_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, ScheduleTriggerDebugEventPoller)
|
||||
|
||||
def test_raises_for_unknown_type(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.START
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
|
||||
def test_raises_when_node_config_missing(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
|
||||
|
||||
class TestSelectTriggerDebugEvents:
|
||||
def test_returns_first_non_none_event(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
|
||||
app_model = MagicMock()
|
||||
app_model.tenant_id = "t1"
|
||||
app_model.id = "a1"
|
||||
|
||||
with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll:
|
||||
expected = MagicMock()
|
||||
mock_poll.return_value = expected
|
||||
|
||||
result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"])
|
||||
|
||||
assert result is expected
|
||||
|
||||
def test_returns_none_when_no_events(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
|
||||
app_model = MagicMock()
|
||||
app_model.tenant_id = "t1"
|
||||
app_model.id = "a1"
|
||||
|
||||
with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None):
|
||||
result = select_trigger_debug_events(wf, app_model, "u1", ["n1"])
|
||||
|
||||
assert result is None
|
||||
332
api/tests/unit_tests/core/trigger/test_provider.py
Normal file
332
api/tests/unit_tests/core/trigger/test_provider.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""
|
||||
Tests for core.trigger.provider.PluginTriggerProviderController.
|
||||
|
||||
Covers: to_api_entity creation-method logic, credential validation pipeline,
|
||||
schema resolution by type, event lookup, dispatch/invoke/subscribe delegation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
EventParameter,
|
||||
EventParameterType,
|
||||
OAuthSchema,
|
||||
TriggerCreationMethod,
|
||||
)
|
||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||
from tests.unit_tests.core.trigger.conftest import (
|
||||
i18n,
|
||||
make_constructor,
|
||||
make_controller,
|
||||
make_event,
|
||||
make_provider_config,
|
||||
make_provider_entity,
|
||||
make_subscription,
|
||||
)
|
||||
|
||||
ICON_URL = "https://cdn/icon.png"
|
||||
|
||||
|
||||
class TestToApiEntity:
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_includes_icons_when_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png"))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.icon == ICON_URL
|
||||
assert api.icon_dark == ICON_URL
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_icons_none_when_absent(self, mock_plugin_svc):
|
||||
ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.icon is None
|
||||
assert api.icon_dark is None
|
||||
mock_plugin_svc.get_plugin_icon_url.assert_not_called()
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_manual_only_without_schemas(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL]
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert TriggerCreationMethod.OAUTH in api.supported_creation_methods
|
||||
assert TriggerCreationMethod.MANUAL in api.supported_creation_methods
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert TriggerCreationMethod.APIKEY in api.supported_creation_methods
|
||||
|
||||
|
||||
class TestGetEvent:
|
||||
def test_returns_matching_event(self):
|
||||
evt = make_event("push")
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")]))
|
||||
|
||||
assert ctrl.get_event("push") is evt
|
||||
|
||||
def test_returns_none_for_unknown(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
|
||||
|
||||
assert ctrl.get_event("nonexistent") is None
|
||||
|
||||
|
||||
class TestGetSubscriptionDefaultProperties:
|
||||
def test_returns_defaults_skipping_none(self):
|
||||
config1 = make_provider_config("key1")
|
||||
config1.default = "val1"
|
||||
config2 = make_provider_config("key2")
|
||||
config2.default = None
|
||||
ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2]))
|
||||
|
||||
props = ctrl.get_subscription_default_properties()
|
||||
|
||||
assert props == {"key1": "val1"}
|
||||
|
||||
|
||||
class TestValidateCredentials:
|
||||
def test_raises_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
|
||||
with pytest.raises(ValueError, match="Subscription constructor not found"):
|
||||
ctrl.validate_credentials("u1", {"key": "val"})
|
||||
|
||||
def test_raises_for_missing_required_field(self):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
|
||||
with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"):
|
||||
ctrl.validate_credentials("u1", {})
|
||||
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_passes_with_valid_credentials(self, mock_client):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
mock_client.return_value.validate_provider_credentials.return_value = True
|
||||
|
||||
ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise
|
||||
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_raises_when_plugin_rejects(self, mock_client):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
mock_client.return_value.validate_provider_credentials.return_value = None
|
||||
|
||||
with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"):
|
||||
ctrl.validate_credentials("u1", {"api_key": "bad"})
|
||||
|
||||
|
||||
class TestGetSupportedCredentialTypes:
|
||||
def test_empty_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
assert ctrl.get_supported_credential_types() == []
|
||||
|
||||
def test_oauth_only(self):
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.OAUTH2 in types
|
||||
assert CredentialType.API_KEY not in types
|
||||
|
||||
def test_apikey_only(self):
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.API_KEY in types
|
||||
assert CredentialType.OAUTH2 not in types
|
||||
|
||||
def test_both(self):
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")])
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(
|
||||
constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth)
|
||||
)
|
||||
)
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.OAUTH2 in types
|
||||
assert CredentialType.API_KEY in types
|
||||
|
||||
|
||||
class TestGetCredentialsSchema:
|
||||
def test_returns_empty_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
assert ctrl.get_credentials_schema(CredentialType.API_KEY) == []
|
||||
|
||||
def test_returns_apikey_credentials(self):
|
||||
cfg = make_provider_config("token")
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg])))
|
||||
|
||||
result = ctrl.get_credentials_schema(CredentialType.API_KEY)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "token"
|
||||
|
||||
def test_returns_oauth_credentials(self):
|
||||
oauth_cred = make_provider_config("oauth_token")
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
result = ctrl.get_credentials_schema(CredentialType.OAUTH2)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "oauth_token"
|
||||
|
||||
def test_unauthorized_returns_empty(self):
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == []
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor()))
|
||||
with pytest.raises(ValueError, match="Invalid credential type"):
|
||||
ctrl.get_credentials_schema("bogus_type")
|
||||
|
||||
|
||||
class TestGetEventParameters:
|
||||
def test_returns_params_for_known_event(self):
|
||||
param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING)
|
||||
evt = make_event("push", parameters=[param])
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[evt]))
|
||||
|
||||
result = ctrl.get_event_parameters("push")
|
||||
|
||||
assert "branch" in result
|
||||
assert result["branch"].name == "branch"
|
||||
|
||||
def test_returns_empty_for_unknown_event(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
|
||||
|
||||
assert ctrl.get_event_parameters("nonexistent") == {}
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_delegates_to_client(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
expected = MagicMock()
|
||||
mock_client.return_value.dispatch_event.return_value = expected
|
||||
|
||||
result = ctrl.dispatch(
|
||||
request=MagicMock(),
|
||||
subscription=make_subscription(),
|
||||
credentials={"k": "v"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
mock_client.return_value.dispatch_event.assert_called_once()
|
||||
|
||||
|
||||
class TestInvokeTriggerEvent:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_delegates_to_client(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
expected = MagicMock()
|
||||
mock_client.return_value.invoke_trigger_event.return_value = expected
|
||||
|
||||
result = ctrl.invoke_trigger_event(
|
||||
user_id="u1",
|
||||
event_name="push",
|
||||
parameters={},
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
subscription=make_subscription(),
|
||||
request=MagicMock(),
|
||||
payload={},
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
class TestSubscribeTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_returns_validated_subscription(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.subscribe.return_value.subscription = {
|
||||
"expires_at": 123,
|
||||
"endpoint": "https://e",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
result = ctrl.subscribe_trigger(
|
||||
user_id="u1",
|
||||
endpoint="https://e",
|
||||
parameters={},
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result.endpoint == "https://e"
|
||||
|
||||
|
||||
class TestUnsubscribeTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_returns_validated_result(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"}
|
||||
|
||||
result = ctrl.unsubscribe_trigger(
|
||||
user_id="u1",
|
||||
subscription=make_subscription(),
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestRefreshTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_uses_system_user_id(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.refresh.return_value.subscription = {
|
||||
"expires_at": 456,
|
||||
"endpoint": "https://e",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY)
|
||||
|
||||
call_kwargs = mock_client.return_value.refresh.call_args[1]
|
||||
assert call_kwargs["user_id"] == "system"
|
||||
307
api/tests/unit_tests/core/trigger/test_trigger_manager.py
Normal file
307
api/tests/unit_tests/core/trigger/test_trigger_manager.py
Normal file
@ -0,0 +1,307 @@
|
||||
"""
|
||||
Tests for core.trigger.trigger_manager.TriggerManager.
|
||||
|
||||
Covers: icon URL construction, provider listing with error resilience,
|
||||
double-check lock caching, error translation, EventIgnoreError -> cancelled,
|
||||
and delegation to provider controller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from threading import Lock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
|
||||
from core.trigger.errors import EventIgnoreError
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from tests.unit_tests.core.trigger.conftest import (
|
||||
VALID_PROVIDER_ID,
|
||||
make_controller,
|
||||
make_provider_entity,
|
||||
make_subscription,
|
||||
)
|
||||
|
||||
PID = TriggerProviderID(VALID_PROVIDER_ID)
|
||||
PID_STR = str(PID)
|
||||
|
||||
|
||||
class TestGetTriggerPluginIcon:
|
||||
@patch("core.trigger.trigger_manager.dify_config")
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_builds_correct_url(self, mock_client, mock_config):
|
||||
mock_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
provider = MagicMock()
|
||||
provider.declaration.identity.icon = "my-icon.svg"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID)
|
||||
|
||||
assert "tenant_id=tenant-1" in url
|
||||
assert "filename=my-icon.svg" in url
|
||||
assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon")
|
||||
|
||||
|
||||
class TestListPluginTriggerProviders:
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_wraps_entities_into_controllers(self, mock_client):
|
||||
entity = MagicMock()
|
||||
entity.declaration = make_provider_entity("p1")
|
||||
entity.plugin_id = "plugin-1"
|
||||
entity.plugin_unique_identifier = "uid-1"
|
||||
entity.provider = VALID_PROVIDER_ID
|
||||
mock_client.return_value.fetch_trigger_providers.return_value = [entity]
|
||||
|
||||
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
|
||||
|
||||
assert len(controllers) == 1
|
||||
assert controllers[0].plugin_id == "plugin-1"
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_skips_failing_providers(self, mock_client):
|
||||
good = MagicMock()
|
||||
good.declaration = make_provider_entity("good")
|
||||
good.plugin_id = "good-plugin"
|
||||
good.plugin_unique_identifier = "uid-good"
|
||||
good.provider = VALID_PROVIDER_ID
|
||||
|
||||
bad = MagicMock()
|
||||
bad.declaration = make_provider_entity("bad")
|
||||
bad.plugin_id = "bad-plugin"
|
||||
bad.plugin_unique_identifier = "uid-bad"
|
||||
bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation
|
||||
|
||||
mock_client.return_value.fetch_trigger_providers.return_value = [bad, good]
|
||||
|
||||
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
|
||||
|
||||
assert len(controllers) == 1
|
||||
assert controllers[0].plugin_id == "good-plugin"
|
||||
|
||||
|
||||
class TestGetTriggerProvider:
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_initializes_context_on_first_call(self, mock_ctx, mock_client):
|
||||
# get() called 3 times: (1) try block, (2) after set, (3) under lock
|
||||
mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}]
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
provider = MagicMock()
|
||||
provider.declaration = make_provider_entity()
|
||||
provider.plugin_id = "p1"
|
||||
provider.plugin_unique_identifier = "uid-1"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
mock_ctx.plugin_trigger_providers.set.assert_called_once_with({})
|
||||
mock_ctx.plugin_trigger_providers_lock.set.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_returns_cached_without_fetch(self, mock_ctx, mock_client):
|
||||
cached = make_controller()
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached}
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is cached
|
||||
mock_client.return_value.fetch_trigger_provider.assert_not_called()
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client):
|
||||
cached = make_controller()
|
||||
mock_ctx.plugin_trigger_providers.get.side_effect = [
|
||||
{}, # first check misses
|
||||
{PID_STR: cached}, # under-lock check hits
|
||||
]
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is cached
|
||||
mock_client.return_value.fetch_trigger_provider.assert_not_called()
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client):
|
||||
cache: dict = {}
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = cache
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
provider = MagicMock()
|
||||
provider.declaration = make_provider_entity()
|
||||
provider.plugin_id = "p1"
|
||||
provider.plugin_unique_identifier = "uid-1"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is not None
|
||||
assert PID_STR in cache
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_none_fetch_raises_value_error(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing"))
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error")
|
||||
|
||||
with pytest.raises(PluginDaemonError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
|
||||
|
||||
|
||||
class TestListAllTriggerProviders:
|
||||
@patch.object(TriggerManager, "list_plugin_trigger_providers")
|
||||
def test_delegates_to_list_plugin(self, mock_list):
|
||||
expected = [make_controller()]
|
||||
mock_list.return_value = expected
|
||||
|
||||
assert TriggerManager.list_all_trigger_providers("t1") is expected
|
||||
mock_list.assert_called_once_with("t1")
|
||||
|
||||
|
||||
class TestListTriggersByProvider:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_returns_provider_events(self, mock_get):
|
||||
ctrl = make_controller()
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.list_triggers_by_provider("t1", PID)
|
||||
|
||||
assert result == ctrl.get_events()
|
||||
|
||||
|
||||
class TestInvokeTriggerEvent:
|
||||
def _args(self):
|
||||
return {
|
||||
"tenant_id": "t1",
|
||||
"user_id": "u1",
|
||||
"provider_id": PID,
|
||||
"event_name": "on_push",
|
||||
"parameters": {"branch": "main"},
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_type": CredentialType.API_KEY,
|
||||
"subscription": make_subscription(),
|
||||
"request": MagicMock(),
|
||||
"payload": {"action": "push"},
|
||||
}
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_returns_invoke_response(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False)
|
||||
ctrl.invoke_trigger_event.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
assert result is expected
|
||||
assert result.cancelled is False
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_event_ignore_returns_cancelled(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip")
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
assert result.cancelled is True
|
||||
assert result.variables == {}
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_other_errors_propagate(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
ctrl.invoke_trigger_event.side_effect = RuntimeError("boom")
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
|
||||
class TestSubscribeTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = make_subscription()
|
||||
ctrl.subscribe_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.subscribe_trigger(
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
provider_id=PID,
|
||||
endpoint="https://hook.test",
|
||||
parameters={"f": "all"},
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
ctrl.subscribe_trigger.assert_called_once()
|
||||
|
||||
|
||||
class TestUnsubscribeTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = MagicMock()
|
||||
ctrl.unsubscribe_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
sub = make_subscription()
|
||||
|
||||
result = TriggerManager.unsubscribe_trigger(
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
provider_id=PID,
|
||||
subscription=sub,
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
class TestRefreshTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = make_subscription()
|
||||
ctrl.refresh_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.refresh_trigger(
|
||||
tenant_id="t1",
|
||||
provider_id=PID,
|
||||
subscription=make_subscription(),
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
0
api/tests/unit_tests/core/trigger/utils/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/utils/__init__.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Tests for core.trigger.utils.encryption — masking logic and cache key generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.utils.encryption import (
|
||||
TriggerProviderCredentialsCache,
|
||||
TriggerProviderOAuthClientParamsCache,
|
||||
TriggerProviderPropertiesCache,
|
||||
masked_credentials,
|
||||
)
|
||||
|
||||
|
||||
def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig:
|
||||
return ProviderConfig(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
type=field_type,
|
||||
)
|
||||
|
||||
|
||||
class TestMaskedCredentials:
|
||||
def test_short_secret_fully_masked(self):
|
||||
schema = [_make_schema("key", "secret-input")]
|
||||
result = masked_credentials(schema, {"key": "ab"})
|
||||
assert result["key"] == "**"
|
||||
|
||||
def test_long_secret_partially_masked(self):
|
||||
schema = [_make_schema("key", "secret-input")]
|
||||
result = masked_credentials(schema, {"key": "abcdef"})
|
||||
assert result["key"].startswith("ab")
|
||||
assert result["key"].endswith("ef")
|
||||
assert "**" in result["key"]
|
||||
|
||||
def test_non_secret_field_unchanged(self):
|
||||
schema = [_make_schema("host", "text-input")]
|
||||
result = masked_credentials(schema, {"host": "example.com"})
|
||||
assert result["host"] == "example.com"
|
||||
|
||||
def test_unknown_key_passes_through(self):
|
||||
result = masked_credentials([], {"unknown": "value"})
|
||||
assert result["unknown"] == "value"
|
||||
|
||||
|
||||
class TestCacheKeyGeneration:
|
||||
def test_credentials_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
assert "c1" in cache.cache_key
|
||||
|
||||
def test_oauth_client_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
|
||||
def test_properties_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
assert "s1" in cache.cache_key
|
||||
@ -0,0 +1,31 @@
|
||||
"""Tests for core.trigger.utils.endpoint — URL generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.trigger.utils import endpoint
|
||||
|
||||
|
||||
class TestGeneratePluginTriggerEndpointUrl:
|
||||
def test_builds_correct_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123")
|
||||
|
||||
assert url == "https://api.example.com/triggers/plugin/endpoint-123"
|
||||
|
||||
|
||||
class TestGenerateWebhookTriggerEndpoint:
|
||||
def test_non_debug_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False)
|
||||
|
||||
assert url == "https://api.example.com/triggers/webhook/sub-456"
|
||||
|
||||
def test_debug_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True)
|
||||
|
||||
assert url == "https://api.example.com/triggers/webhook-debug/sub-456"
|
||||
23
api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
Normal file
23
api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Tests for core.trigger.utils.locks — Redis lock key builders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys
|
||||
|
||||
|
||||
class TestBuildTriggerRefreshLockKey:
|
||||
def test_correct_format(self):
|
||||
key = build_trigger_refresh_lock_key("tenant-1", "sub-1")
|
||||
|
||||
assert key == "trigger_provider_refresh_lock:tenant-1_sub-1"
|
||||
|
||||
|
||||
class TestBuildTriggerRefreshLockKeys:
|
||||
def test_maps_over_pairs(self):
|
||||
pairs = [("t1", "s1"), ("t2", "s2")]
|
||||
|
||||
keys = build_trigger_refresh_lock_keys(pairs)
|
||||
|
||||
assert len(keys) == 2
|
||||
assert keys[0] == "trigger_provider_refresh_lock:t1_s1"
|
||||
assert keys[1] == "trigger_provider_refresh_lock:t2_s2"
|
||||
Loading…
Reference in New Issue
Block a user