test: add UTs for api core.trigger (#32587)

This commit is contained in:
Dev Sharma 2026-03-10 07:48:32 +05:30 committed by GitHub
parent 3f3b788356
commit 4f835107b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1217 additions and 0 deletions

View File

@ -0,0 +1,93 @@
"""Shared factory helpers for core.trigger test suite."""
from __future__ import annotations
from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.trigger.entities.entities import (
EventEntity,
EventIdentity,
EventParameter,
OAuthSchema,
Subscription,
SubscriptionConstructor,
TriggerProviderEntity,
TriggerProviderIdentity,
)
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID
# Valid format for TriggerProviderID: org/plugin/provider
VALID_PROVIDER_ID = "testorg/testplugin/testprovider"
def i18n(text: str = "test") -> I18nObject:
return I18nObject(en_US=text, zh_Hans=text)
def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity:
return EventEntity(
identity=EventIdentity(author="a", name=name, label=i18n(name)),
description=i18n(name),
parameters=parameters or [],
)
def make_provider_entity(
name: str = "test_provider",
events: list[EventEntity] | None = None,
constructor: SubscriptionConstructor | None = None,
subscription_schema: list[ProviderConfig] | None = None,
icon: str | None = "icon.png",
icon_dark: str | None = None,
) -> TriggerProviderEntity:
return TriggerProviderEntity(
identity=TriggerProviderIdentity(
author="a",
name=name,
label=i18n(name),
description=i18n(name),
icon=icon,
icon_dark=icon_dark,
),
events=events if events is not None else [make_event()],
subscription_constructor=constructor,
subscription_schema=subscription_schema or [],
)
def make_controller(
entity: TriggerProviderEntity | None = None,
tenant_id: str = "tenant-1",
provider_id: str = VALID_PROVIDER_ID,
) -> PluginTriggerProviderController:
return PluginTriggerProviderController(
entity=entity or make_provider_entity(),
plugin_id="plugin-1",
plugin_unique_identifier="uid-1",
provider_id=TriggerProviderID(provider_id),
tenant_id=tenant_id,
)
def make_subscription(**overrides: Any) -> Subscription:
defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}}
defaults.update(overrides)
return Subscription(**defaults)
def make_provider_config(
name: str = "api_key", required: bool = True, config_type: str = "secret-input"
) -> ProviderConfig:
return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required)
def make_constructor(
credentials_schema: list[ProviderConfig] | None = None,
oauth_schema: OAuthSchema | None = None,
) -> SubscriptionConstructor:
return SubscriptionConstructor(
parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema
)

View File

@ -0,0 +1,93 @@
"""
Tests for core.trigger.debug.event_bus.TriggerDebugEventBus.
Covers: Lua-script dispatch/poll with Redis error resilience.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from redis import RedisError
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import PluginTriggerDebugEvent
class TestDispatch:
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_dispatch_count(self, mock_redis):
mock_redis.eval.return_value = 3
event = MagicMock()
event.model_dump_json.return_value = '{"test": true}'
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
assert result == 3
mock_redis.eval.assert_called_once()
@patch("core.trigger.debug.event_bus.redis_client")
def test_redis_error_returns_zero(self, mock_redis):
mock_redis.eval.side_effect = RedisError("connection lost")
event = MagicMock()
event.model_dump_json.return_value = "{}"
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
assert result == 0
class TestPoll:
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_deserialized_event(self, mock_redis):
event_json = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
).model_dump_json()
mock_redis.eval.return_value = event_json
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is not None
assert result.name == "push"
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_none_when_no_event(self, mock_redis):
mock_redis.eval.return_value = None
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is None
@patch("core.trigger.debug.event_bus.redis_client")
def test_redis_error_returns_none(self, mock_redis):
mock_redis.eval.side_effect = RedisError("timeout")
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is None

View File

@ -0,0 +1,276 @@
"""
Tests for core.trigger.debug.event_selectors.
Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory,
and select_trigger_debug_events orchestrator.
"""
from __future__ import annotations
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.debug.event_selectors import (
PluginTriggerDebugEventPoller,
ScheduleTriggerDebugEventPoller,
WebhookTriggerDebugEventPoller,
create_event_poller,
select_trigger_debug_events,
)
from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
from core.workflow.enums import NodeType
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
def _make_poller_args(node_config: dict | None = None) -> dict:
return {
"tenant_id": "t1",
"user_id": "u1",
"app_id": "a1",
"node_config": node_config or {"data": {}},
"node_id": "n1",
}
def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict:
"""Valid node config for TriggerEventNodeData.model_validate."""
return {
"data": {
"title": "test",
"plugin_id": "org/testplugin",
"provider_id": provider_id,
"event_name": "push",
"subscription_id": "s1",
"plugin_unique_identifier": "uid-1",
}
}
class TestPluginTriggerDebugEventPoller:
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_workflow_args_on_success(self, mock_bus):
event = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
)
mock_bus.poll.return_value = event
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
variables={"repo": "dify"},
cancelled=False,
)
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"repo": "dify"}
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_no_event(self, mock_bus):
mock_bus.poll.return_value = None
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
assert poller.poll() is None
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_invoke_cancelled(self, mock_bus):
event = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
)
mock_bus.poll.return_value = event
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
variables={},
cancelled=True,
)
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
assert poller.poll() is None
class TestWebhookTriggerDebugEventPoller:
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_uses_inputs_directly_when_present(self, mock_bus):
event = WebhookDebugEvent(
timestamp=100,
request_id="r1",
node_id="n1",
payload={"inputs": {"key": "val"}, "webhook_data": {}},
)
mock_bus.poll.return_value = event
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"key": "val"}
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_falls_back_to_webhook_data(self, mock_bus):
event = WebhookDebugEvent(
timestamp=100,
request_id="r1",
node_id="n1",
payload={"webhook_data": {"body": "raw"}},
)
mock_bus.poll.return_value = event
with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc:
mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"}
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"parsed": "data"}
mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"})
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_no_event(self, mock_bus):
mock_bus.poll.return_value = None
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
assert poller.poll() is None
class TestScheduleTriggerDebugEventPoller:
def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime):
"""Set up mocks and create a schedule poller."""
mock_redis.get.return_value = None
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
return ScheduleTriggerDebugEventPoller(**_make_poller_args())
@patch("core.trigger.debug.event_selectors.redis_client")
@patch("core.trigger.debug.event_selectors.naive_utc_now")
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
now = datetime(2025, 1, 1, 12, 0, 0)
next_run = datetime(2025, 1, 1, 13, 0, 0) # future
mock_now.return_value = now
mock_calc.return_value = next_run
mock_ensure.return_value = next_run
mock_redis.get.return_value = None
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
assert poller.poll() is None
@patch("core.trigger.debug.event_selectors.redis_client")
@patch("core.trigger.debug.event_selectors.naive_utc_now")
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
now = datetime(2025, 1, 1, 14, 0, 0)
next_run = datetime(2025, 1, 1, 12, 0, 0) # past
mock_now.return_value = now
mock_calc.return_value = next_run
mock_ensure.return_value = next_run
mock_redis.get.return_value = None
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
mock_redis.delete.assert_called_once()
class TestCreateEventPoller:
def _workflow_with_node(self, node_type: NodeType):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = node_type
return wf
def test_creates_plugin_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, PluginTriggerDebugEventPoller)
def test_creates_webhook_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, WebhookTriggerDebugEventPoller)
def test_creates_schedule_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, ScheduleTriggerDebugEventPoller)
def test_raises_for_unknown_type(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.START
with pytest.raises(ValueError):
create_event_poller(wf, "t1", "u1", "a1", "n1")
def test_raises_when_node_config_missing(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = None
with pytest.raises(ValueError):
create_event_poller(wf, "t1", "u1", "a1", "n1")
class TestSelectTriggerDebugEvents:
def test_returns_first_non_none_event(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
app_model = MagicMock()
app_model.tenant_id = "t1"
app_model.id = "a1"
with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll:
expected = MagicMock()
mock_poll.return_value = expected
result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"])
assert result is expected
def test_returns_none_when_no_events(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
app_model = MagicMock()
app_model.tenant_id = "t1"
app_model.id = "a1"
with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None):
result = select_trigger_debug_events(wf, app_model, "u1", ["n1"])
assert result is None

View File

@ -0,0 +1,332 @@
"""
Tests for core.trigger.provider.PluginTriggerProviderController.
Covers: to_api_entity creation-method logic, credential validation pipeline,
schema resolution by type, event lookup, dispatch/invoke/subscribe delegation.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
EventParameter,
EventParameterType,
OAuthSchema,
TriggerCreationMethod,
)
from core.trigger.errors import TriggerProviderCredentialValidationError
from tests.unit_tests.core.trigger.conftest import (
i18n,
make_constructor,
make_controller,
make_event,
make_provider_config,
make_provider_entity,
make_subscription,
)
ICON_URL = "https://cdn/icon.png"
class TestToApiEntity:
@patch("core.trigger.provider.PluginService")
def test_includes_icons_when_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png"))
api = ctrl.to_api_entity()
assert api.icon == ICON_URL
assert api.icon_dark == ICON_URL
@patch("core.trigger.provider.PluginService")
def test_icons_none_when_absent(self, mock_plugin_svc):
ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None))
api = ctrl.to_api_entity()
assert api.icon is None
assert api.icon_dark is None
mock_plugin_svc.get_plugin_icon_url.assert_not_called()
@patch("core.trigger.provider.PluginService")
def test_manual_only_without_schemas(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(entity=make_provider_entity(constructor=None))
api = ctrl.to_api_entity()
assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL]
@patch("core.trigger.provider.PluginService")
def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
api = ctrl.to_api_entity()
assert TriggerCreationMethod.OAUTH in api.supported_creation_methods
assert TriggerCreationMethod.MANUAL in api.supported_creation_methods
@patch("core.trigger.provider.PluginService")
def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
api = ctrl.to_api_entity()
assert TriggerCreationMethod.APIKEY in api.supported_creation_methods
class TestGetEvent:
def test_returns_matching_event(self):
evt = make_event("push")
ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")]))
assert ctrl.get_event("push") is evt
def test_returns_none_for_unknown(self):
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
assert ctrl.get_event("nonexistent") is None
class TestGetSubscriptionDefaultProperties:
def test_returns_defaults_skipping_none(self):
config1 = make_provider_config("key1")
config1.default = "val1"
config2 = make_provider_config("key2")
config2.default = None
ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2]))
props = ctrl.get_subscription_default_properties()
assert props == {"key1": "val1"}
class TestValidateCredentials:
def test_raises_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
with pytest.raises(ValueError, match="Subscription constructor not found"):
ctrl.validate_credentials("u1", {"key": "val"})
def test_raises_for_missing_required_field(self):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"):
ctrl.validate_credentials("u1", {})
@patch("core.trigger.provider.PluginTriggerClient")
def test_passes_with_valid_credentials(self, mock_client):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
mock_client.return_value.validate_provider_credentials.return_value = True
ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise
@patch("core.trigger.provider.PluginTriggerClient")
def test_raises_when_plugin_rejects(self, mock_client):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
mock_client.return_value.validate_provider_credentials.return_value = None
with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"):
ctrl.validate_credentials("u1", {"api_key": "bad"})
class TestGetSupportedCredentialTypes:
def test_empty_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
assert ctrl.get_supported_credential_types() == []
def test_oauth_only(self):
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
types = ctrl.get_supported_credential_types()
assert CredentialType.OAUTH2 in types
assert CredentialType.API_KEY not in types
def test_apikey_only(self):
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
types = ctrl.get_supported_credential_types()
assert CredentialType.API_KEY in types
assert CredentialType.OAUTH2 not in types
def test_both(self):
oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")])
ctrl = make_controller(
entity=make_provider_entity(
constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth)
)
)
types = ctrl.get_supported_credential_types()
assert CredentialType.OAUTH2 in types
assert CredentialType.API_KEY in types
class TestGetCredentialsSchema:
def test_returns_empty_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
assert ctrl.get_credentials_schema(CredentialType.API_KEY) == []
def test_returns_apikey_credentials(self):
cfg = make_provider_config("token")
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg])))
result = ctrl.get_credentials_schema(CredentialType.API_KEY)
assert len(result) == 1
assert result[0].name == "token"
def test_returns_oauth_credentials(self):
oauth_cred = make_provider_config("oauth_token")
oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
result = ctrl.get_credentials_schema(CredentialType.OAUTH2)
assert len(result) == 1
assert result[0].name == "oauth_token"
def test_unauthorized_returns_empty(self):
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == []
def test_invalid_type_raises(self):
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor()))
with pytest.raises(ValueError, match="Invalid credential type"):
ctrl.get_credentials_schema("bogus_type")
class TestGetEventParameters:
def test_returns_params_for_known_event(self):
param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING)
evt = make_event("push", parameters=[param])
ctrl = make_controller(entity=make_provider_entity(events=[evt]))
result = ctrl.get_event_parameters("push")
assert "branch" in result
assert result["branch"].name == "branch"
def test_returns_empty_for_unknown_event(self):
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
assert ctrl.get_event_parameters("nonexistent") == {}
class TestDispatch:
@patch("core.trigger.provider.PluginTriggerClient")
def test_delegates_to_client(self, mock_client):
ctrl = make_controller()
expected = MagicMock()
mock_client.return_value.dispatch_event.return_value = expected
result = ctrl.dispatch(
request=MagicMock(),
subscription=make_subscription(),
credentials={"k": "v"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
mock_client.return_value.dispatch_event.assert_called_once()
class TestInvokeTriggerEvent:
@patch("core.trigger.provider.PluginTriggerClient")
def test_delegates_to_client(self, mock_client):
ctrl = make_controller()
expected = MagicMock()
mock_client.return_value.invoke_trigger_event.return_value = expected
result = ctrl.invoke_trigger_event(
user_id="u1",
event_name="push",
parameters={},
credentials={},
credential_type=CredentialType.API_KEY,
subscription=make_subscription(),
request=MagicMock(),
payload={},
)
assert result is expected
class TestSubscribeTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_returns_validated_subscription(self, mock_client):
ctrl = make_controller()
mock_client.return_value.subscribe.return_value.subscription = {
"expires_at": 123,
"endpoint": "https://e",
"properties": {},
}
result = ctrl.subscribe_trigger(
user_id="u1",
endpoint="https://e",
parameters={},
credentials={},
credential_type=CredentialType.API_KEY,
)
assert result.endpoint == "https://e"
class TestUnsubscribeTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_returns_validated_result(self, mock_client):
ctrl = make_controller()
mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"}
result = ctrl.unsubscribe_trigger(
user_id="u1",
subscription=make_subscription(),
credentials={},
credential_type=CredentialType.API_KEY,
)
assert result.success is True
class TestRefreshTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_uses_system_user_id(self, mock_client):
ctrl = make_controller()
mock_client.return_value.refresh.return_value.subscription = {
"expires_at": 456,
"endpoint": "https://e",
"properties": {},
}
ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY)
call_kwargs = mock_client.return_value.refresh.call_args[1]
assert call_kwargs["user_id"] == "system"

View File

@ -0,0 +1,307 @@
"""
Tests for core.trigger.trigger_manager.TriggerManager.
Covers: icon URL construction, provider listing with error resilience,
double-check lock caching, error translation, EventIgnoreError -> cancelled,
and delegation to provider controller.
"""
from __future__ import annotations
from threading import Lock
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
from core.trigger.errors import EventIgnoreError
from core.trigger.trigger_manager import TriggerManager
from models.provider_ids import TriggerProviderID
from tests.unit_tests.core.trigger.conftest import (
VALID_PROVIDER_ID,
make_controller,
make_provider_entity,
make_subscription,
)
PID = TriggerProviderID(VALID_PROVIDER_ID)
PID_STR = str(PID)
class TestGetTriggerPluginIcon:
@patch("core.trigger.trigger_manager.dify_config")
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_builds_correct_url(self, mock_client, mock_config):
mock_config.CONSOLE_API_URL = "https://console.example.com"
provider = MagicMock()
provider.declaration.identity.icon = "my-icon.svg"
mock_client.return_value.fetch_trigger_provider.return_value = provider
url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID)
assert "tenant_id=tenant-1" in url
assert "filename=my-icon.svg" in url
assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon")
class TestListPluginTriggerProviders:
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_wraps_entities_into_controllers(self, mock_client):
entity = MagicMock()
entity.declaration = make_provider_entity("p1")
entity.plugin_id = "plugin-1"
entity.plugin_unique_identifier = "uid-1"
entity.provider = VALID_PROVIDER_ID
mock_client.return_value.fetch_trigger_providers.return_value = [entity]
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
assert len(controllers) == 1
assert controllers[0].plugin_id == "plugin-1"
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_skips_failing_providers(self, mock_client):
good = MagicMock()
good.declaration = make_provider_entity("good")
good.plugin_id = "good-plugin"
good.plugin_unique_identifier = "uid-good"
good.provider = VALID_PROVIDER_ID
bad = MagicMock()
bad.declaration = make_provider_entity("bad")
bad.plugin_id = "bad-plugin"
bad.plugin_unique_identifier = "uid-bad"
bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation
mock_client.return_value.fetch_trigger_providers.return_value = [bad, good]
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
assert len(controllers) == 1
assert controllers[0].plugin_id == "good-plugin"
class TestGetTriggerProvider:
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_initializes_context_on_first_call(self, mock_ctx, mock_client):
# get() called 3 times: (1) try block, (2) after set, (3) under lock
mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}]
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
provider = MagicMock()
provider.declaration = make_provider_entity()
provider.plugin_id = "p1"
provider.plugin_unique_identifier = "uid-1"
mock_client.return_value.fetch_trigger_provider.return_value = provider
result = TriggerManager.get_trigger_provider("t1", PID)
mock_ctx.plugin_trigger_providers.set.assert_called_once_with({})
mock_ctx.plugin_trigger_providers_lock.set.assert_called_once()
assert result is not None
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_returns_cached_without_fetch(self, mock_ctx, mock_client):
cached = make_controller()
mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached}
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is cached
mock_client.return_value.fetch_trigger_provider.assert_not_called()
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client):
cached = make_controller()
mock_ctx.plugin_trigger_providers.get.side_effect = [
{}, # first check misses
{PID_STR: cached}, # under-lock check hits
]
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is cached
mock_client.return_value.fetch_trigger_provider.assert_not_called()
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client):
cache: dict = {}
mock_ctx.plugin_trigger_providers.get.return_value = cache
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
provider = MagicMock()
provider.declaration = make_provider_entity()
provider.plugin_id = "p1"
provider.plugin_unique_identifier = "uid-1"
mock_client.return_value.fetch_trigger_provider.return_value = provider
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is not None
assert PID_STR in cache
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_none_fetch_raises_value_error(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.return_value = None
with pytest.raises(ValueError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing"))
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone")
with pytest.raises(ValueError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error")
with pytest.raises(PluginDaemonError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
class TestListAllTriggerProviders:
@patch.object(TriggerManager, "list_plugin_trigger_providers")
def test_delegates_to_list_plugin(self, mock_list):
expected = [make_controller()]
mock_list.return_value = expected
assert TriggerManager.list_all_trigger_providers("t1") is expected
mock_list.assert_called_once_with("t1")
class TestListTriggersByProvider:
@patch.object(TriggerManager, "get_trigger_provider")
def test_returns_provider_events(self, mock_get):
ctrl = make_controller()
mock_get.return_value = ctrl
result = TriggerManager.list_triggers_by_provider("t1", PID)
assert result == ctrl.get_events()
class TestInvokeTriggerEvent:
def _args(self):
return {
"tenant_id": "t1",
"user_id": "u1",
"provider_id": PID,
"event_name": "on_push",
"parameters": {"branch": "main"},
"credentials": {"token": "abc"},
"credential_type": CredentialType.API_KEY,
"subscription": make_subscription(),
"request": MagicMock(),
"payload": {"action": "push"},
}
@patch.object(TriggerManager, "get_trigger_provider")
def test_returns_invoke_response(self, mock_get):
ctrl = MagicMock()
expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False)
ctrl.invoke_trigger_event.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.invoke_trigger_event(**self._args())
assert result is expected
assert result.cancelled is False
@patch.object(TriggerManager, "get_trigger_provider")
def test_event_ignore_returns_cancelled(self, mock_get):
ctrl = MagicMock()
ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip")
mock_get.return_value = ctrl
result = TriggerManager.invoke_trigger_event(**self._args())
assert result.cancelled is True
assert result.variables == {}
@patch.object(TriggerManager, "get_trigger_provider")
def test_other_errors_propagate(self, mock_get):
ctrl = MagicMock()
ctrl.invoke_trigger_event.side_effect = RuntimeError("boom")
mock_get.return_value = ctrl
with pytest.raises(RuntimeError, match="boom"):
TriggerManager.invoke_trigger_event(**self._args())
class TestSubscribeTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = make_subscription()
ctrl.subscribe_trigger.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.subscribe_trigger(
tenant_id="t1",
user_id="u1",
provider_id=PID,
endpoint="https://hook.test",
parameters={"f": "all"},
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
ctrl.subscribe_trigger.assert_called_once()
class TestUnsubscribeTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = MagicMock()
ctrl.unsubscribe_trigger.return_value = expected
mock_get.return_value = ctrl
sub = make_subscription()
result = TriggerManager.unsubscribe_trigger(
tenant_id="t1",
user_id="u1",
provider_id=PID,
subscription=sub,
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
class TestRefreshTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = make_subscription()
ctrl.refresh_trigger.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.refresh_trigger(
tenant_id="t1",
provider_id=PID,
subscription=make_subscription(),
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected

View File

@ -0,0 +1,62 @@
"""Tests for core.trigger.utils.encryption — masking logic and cache key generation."""
from __future__ import annotations
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.trigger.utils.encryption import (
TriggerProviderCredentialsCache,
TriggerProviderOAuthClientParamsCache,
TriggerProviderPropertiesCache,
masked_credentials,
)
def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig:
return ProviderConfig(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
type=field_type,
)
class TestMaskedCredentials:
def test_short_secret_fully_masked(self):
schema = [_make_schema("key", "secret-input")]
result = masked_credentials(schema, {"key": "ab"})
assert result["key"] == "**"
def test_long_secret_partially_masked(self):
schema = [_make_schema("key", "secret-input")]
result = masked_credentials(schema, {"key": "abcdef"})
assert result["key"].startswith("ab")
assert result["key"].endswith("ef")
assert "**" in result["key"]
def test_non_secret_field_unchanged(self):
schema = [_make_schema("host", "text-input")]
result = masked_credentials(schema, {"host": "example.com"})
assert result["host"] == "example.com"
def test_unknown_key_passes_through(self):
result = masked_credentials([], {"unknown": "value"})
assert result["unknown"] == "value"
class TestCacheKeyGeneration:
def test_credentials_cache_key_contains_ids(self):
cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
assert "c1" in cache.cache_key
def test_oauth_client_cache_key_contains_ids(self):
cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
def test_properties_cache_key_contains_ids(self):
cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
assert "s1" in cache.cache_key

View File

@ -0,0 +1,31 @@
"""Tests for core.trigger.utils.endpoint — URL generation."""
from __future__ import annotations
from unittest.mock import patch
from yarl import URL
from core.trigger.utils import endpoint
class TestGeneratePluginTriggerEndpointUrl:
def test_builds_correct_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123")
assert url == "https://api.example.com/triggers/plugin/endpoint-123"
class TestGenerateWebhookTriggerEndpoint:
def test_non_debug_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False)
assert url == "https://api.example.com/triggers/webhook/sub-456"
def test_debug_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True)
assert url == "https://api.example.com/triggers/webhook-debug/sub-456"

View File

@ -0,0 +1,23 @@
"""Tests for core.trigger.utils.locks — Redis lock key builders."""
from __future__ import annotations
from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys
class TestBuildTriggerRefreshLockKey:
def test_correct_format(self):
key = build_trigger_refresh_lock_key("tenant-1", "sub-1")
assert key == "trigger_provider_refresh_lock:tenant-1_sub-1"
class TestBuildTriggerRefreshLockKeys:
def test_maps_over_pairs(self):
pairs = [("t1", "s1"), ("t2", "s2")]
keys = build_trigger_refresh_lock_keys(pairs)
assert len(keys) == 2
assert keys[0] == "trigger_provider_refresh_lock:t1_s1"
assert keys[1] == "trigger_provider_refresh_lock:t2_s2"