test: split merged API test modules and remove F811 ignore (#35105)

This commit is contained in:
Authored by 99 on 2026-04-14 11:54:30 +08:00, committed by GitHub
parent 178883b4cc
commit 28185170b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 2214 additions and 2544 deletions

View File

@ -97,7 +97,6 @@ ignore = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]

View File

@ -95,30 +95,6 @@ class TestTextToAudioPayload:
assert payload.streaming is True
# ---------------------------------------------------------------------------
# AudioService Interface Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
    """Test AudioService method interfaces exist."""

    # NOTE(review): a second class with this same name is defined later in the
    # module (see below) — a redefinition ruff flags as F811; the later class
    # silently shadows this one, so these two tests never run under pytest.

    def test_transcript_asr_method_exists(self):
        """Test that AudioService.transcript_asr exists."""
        assert hasattr(AudioService, "transcript_asr")
        assert callable(AudioService.transcript_asr)

    def test_transcript_tts_method_exists(self):
        """Test that AudioService.transcript_tts exists."""
        assert hasattr(AudioService, "transcript_tts")
        assert callable(AudioService.transcript_tts)
# ---------------------------------------------------------------------------
# Audio Service Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test suite for AudioService interface methods."""

View File

@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi:
with pytest.raises(NotChatAppError):
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Getting suggested questions for the app built by `_completion_app()` raises NotChatAppError.

        A request context is required because the API handler reads from the
        active Flask request.
        """
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            with pytest.raises(NotChatAppError):
                MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
msg_id = uuid4()

View File

@ -73,11 +73,6 @@ class TestAsyncWorkflowService:
mock_dispatcher = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
with (
patch.object(

View File

@ -0,0 +1,602 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
    """Build a ModelProviderService whose provider manager is a MagicMock.

    Returns the service together with the mock so tests can program the
    manager's return values and assert on its calls.
    """
    mocked_manager = MagicMock()
    svc = ModelProviderService()
    svc._get_provider_manager = MagicMock(return_value=mocked_manager)
    return svc, mocked_manager
def _build_provider_configuration(
    *,
    provider_name: str = "openai",
    supported_model_types: list[ModelType] | None = None,
    custom_models: list[Any] | None = None,
    custom_config_available: bool = True,
) -> SimpleNamespace:
    """Assemble a SimpleNamespace that mimics a provider configuration object.

    Only the attributes the service under test reads are populated; everything
    else is left as None/empty placeholders.
    """
    if supported_model_types is None:
        supported_model_types = [ModelType.LLM]
    # Provider entity stub: label mirrors the provider name, schemas are absent.
    provider_entity = SimpleNamespace(
        provider=provider_name,
        label=I18nObject(en_US=provider_name),
        description=None,
        icon_small=None,
        icon_small_dark=None,
        background=None,
        help=None,
        supported_model_types=supported_model_types,
        configurate_methods=[],
        provider_credential_schema=None,
        model_credential_schema=None,
    )
    # Custom configuration: one active credential, caller-supplied models.
    credential_stub = SimpleNamespace(
        current_credential_id="cred-1",
        current_credential_name="Credential 1",
        available_credentials=[],
    )
    custom_configuration = SimpleNamespace(
        provider=credential_stub,
        models=custom_models,
        can_added_models=[],
    )
    system_configuration = SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[])
    return SimpleNamespace(
        provider=provider_entity,
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=custom_configuration,
        system_configuration=system_configuration,
        is_custom_configuration_available=lambda: custom_config_available,
    )
class TestModelProviderServiceConfiguration:
    """Tests for provider-configuration lookup and listing on ModelProviderService."""

    def test__get_provider_configuration_should_return_configuration_when_provider_exists(self) -> None:
        """The configuration stored under the provider key is returned unchanged."""
        service, manager = _create_service_with_mocked_manager()
        provider_configuration = SimpleNamespace(name="provider-config")
        manager.get_configurations.return_value = {"openai": provider_configuration}
        result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
        # Identity check: the service must not copy or wrap the configuration.
        assert result is provider_configuration

    def test__get_provider_configuration_should_raise_error_when_provider_is_missing(self) -> None:
        """A provider absent from the manager's map raises ProviderNotFoundError."""
        service, manager = _create_service_with_mocked_manager()
        manager.get_configurations.return_value = {}
        with pytest.raises(ProviderNotFoundError, match="does not exist"):
            service._get_provider_configuration(tenant_id="tenant-1", provider="missing")

    def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status(self) -> None:
        """Providers not supporting the requested model type are filtered out of the list."""
        service, manager = _create_service_with_mocked_manager()
        # Supports LLM, so it survives the model_type filter below.
        allowed = _build_provider_configuration(
            provider_name="openai",
            supported_model_types=[ModelType.LLM],
            custom_config_available=False,
        )
        # Only supports text embedding, so it must be dropped.
        filtered = _build_provider_configuration(
            provider_name="embedding",
            supported_model_types=[ModelType.TEXT_EMBEDDING],
            custom_config_available=True,
        )
        manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
        result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
        assert len(result) == 1
        assert result[0].provider == "openai"
        # custom_config_available=False surfaces as the "no-configure" status.
        assert result[0].custom_configuration.status.value == "no-configure"

    def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context(self) -> None:
        """Raw model entities returned by the manager are wrapped into response objects."""
        service, manager = _create_service_with_mocked_manager()

        class _Model:
            # Minimal stand-in exposing only the model_dump() payload the service consumes.
            def __init__(self, model_name: str) -> None:
                self.model_name = model_name

            def model_dump(self) -> dict[str, Any]:
                return {
                    "model": self.model_name,
                    "label": {"en_US": self.model_name},
                    "model_type": ModelType.LLM,
                    "features": [],
                    "fetch_from": FetchFrom.PREDEFINED_MODEL,
                    "model_properties": {},
                    "deprecated": False,
                    "status": ModelStatus.ACTIVE,
                    "load_balancing_enabled": False,
                    "has_invalid_load_balancing_configs": False,
                    "provider": {
                        "provider": "openai",
                        "label": {"en_US": "OpenAI"},
                        "icon_small": None,
                        "icon_small_dark": None,
                        "supported_model_types": [ModelType.LLM],
                    },
                }

        provider_configurations = SimpleNamespace(
            get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
        )
        manager.get_configurations.return_value = provider_configurations
        result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
        assert len(result) == 2
        assert result[0].model == "gpt-4o"
        assert result[1].provider.provider == "openai"
        provider_configurations.get_models.assert_called_once_with(provider="openai")
class TestModelProviderServiceDelegation:
    """Tests that ModelProviderService methods delegate to the provider configuration.

    Each parametrize row is:
    (service method to call, kwargs for it, expected provider-configuration
    method, expected args for that method, value the provider method returns).
    A tuple in the "expected args" column means positional delegation; a dict
    means keyword delegation.
    """

    @pytest.mark.parametrize(
        ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
        [
            (
                "get_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "get_provider_credential",
                {"credential_id": "cred-1"},
                {"token": "abc"},
            ),
            (
                "validate_provider_credentials",
                {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
                "validate_provider_credentials",
                ({"token": "abc"},),
                None,
            ),
            (
                "create_provider_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "credentials": {"token": "abc"},
                    "credential_name": "A",
                },
                "create_provider_credential",
                ({"token": "abc"}, "A"),
                None,
            ),
            (
                "update_provider_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "credentials": {"token": "abc"},
                    "credential_id": "cred-1",
                    "credential_name": "B",
                },
                "update_provider_credential",
                {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
                None,
            ),
            (
                "remove_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "delete_provider_credential",
                {"credential_id": "cred-1"},
                None,
            ),
            (
                "switch_active_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "switch_active_provider_credential",
                {"credential_id": "cred-1"},
                None,
            ),
        ],
    )
    def test_provider_credential_methods_should_delegate_to_provider_configuration(
        self,
        method_name: str,
        method_kwargs: dict[str, Any],
        provider_method_name: str,
        provider_call_kwargs: Any,
        provider_return: Any,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Provider-credential methods forward to the matching configuration method."""
        service = ModelProviderService()
        provider_configuration = MagicMock()
        getattr(provider_configuration, provider_method_name).return_value = provider_return
        get_provider_config_mock = MagicMock(return_value=provider_configuration)
        monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
        result = getattr(service, method_name)(**method_kwargs)
        # The configuration is always resolved positionally from (tenant_id, provider).
        get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
        provider_method = getattr(provider_configuration, provider_method_name)
        # Tuple rows delegate positionally, dict rows by keyword, anything else as one arg.
        if isinstance(provider_call_kwargs, tuple):
            provider_method.assert_called_once_with(*provider_call_kwargs)
        elif isinstance(provider_call_kwargs, dict):
            provider_method.assert_called_once_with(**provider_call_kwargs)
        else:
            provider_method.assert_called_once_with(provider_call_kwargs)
        # Only the getter surfaces the provider method's return value.
        if method_name == "get_provider_credential":
            assert result == {"token": "abc"}

    @pytest.mark.parametrize(
        ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
        [
            (
                "get_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "get_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                {"api_key": "x"},
            ),
            (
                "validate_model_credentials",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                },
                "validate_custom_model_credentials",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
                None,
            ),
            (
                "create_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_name": "cred-a",
                },
                "create_custom_model_credential",
                {
                    "model_type": ModelType.LLM,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_name": "cred-a",
                },
                None,
            ),
            (
                "update_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_id": "cred-1",
                    "credential_name": "cred-b",
                },
                "update_custom_model_credential",
                {
                    "model_type": ModelType.LLM,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_id": "cred-1",
                    "credential_name": "cred-b",
                },
                None,
            ),
            (
                "remove_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "delete_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "switch_active_custom_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "switch_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "add_model_credential_to_model_list",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "add_model_credential_to_model",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "remove_model",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                },
                "delete_custom_model",
                {"model_type": ModelType.LLM, "model": "gpt-4o"},
                None,
            ),
        ],
    )
    def test_custom_model_methods_should_convert_model_type_and_delegate(
        self,
        method_name: str,
        method_kwargs: dict[str, Any],
        provider_method_name: str,
        expected_kwargs: dict[str, Any],
        provider_return: Any,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Custom-model methods convert the model_type string to ModelType and delegate."""
        service = ModelProviderService()
        provider_configuration = MagicMock()
        getattr(provider_configuration, provider_method_name).return_value = provider_return
        get_provider_config_mock = MagicMock(return_value=provider_configuration)
        monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
        result = getattr(service, method_name)(**method_kwargs)
        get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
        getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
        # Only the getter surfaces the provider method's return value.
        if method_name == "get_model_credential":
            assert result == {"api_key": "x"}
class TestModelProviderServiceListingsAndDefaults:
    """Tests for model listing, parameter rules, defaults, icons and enablement."""

    def test_get_models_by_model_type_should_group_active_non_deprecated_models(self) -> None:
        """Deprecated models are dropped; surviving models are grouped by provider."""
        service, manager = _create_service_with_mocked_manager()
        openai_provider = SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            icon_small_dark=None,
        )
        anthropic_provider = SimpleNamespace(
            provider="anthropic",
            label=I18nObject(en_US="Anthropic"),
            icon_small=None,
            icon_small_dark=None,
        )
        # One active model plus two deprecated ones (across two providers).
        models = [
            SimpleNamespace(
                provider=openai_provider,
                model="gpt-4o",
                label=I18nObject(en_US="GPT-4o"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=False,
            ),
            SimpleNamespace(
                provider=openai_provider,
                model="old-openai",
                label=I18nObject(en_US="Old OpenAI"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=True,
            ),
            SimpleNamespace(
                provider=anthropic_provider,
                model="old-anthropic",
                label=I18nObject(en_US="Old Anthropic"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=True,
            ),
        ]
        provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
        manager.get_configurations.return_value = provider_configurations
        result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
        provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
        # Anthropic's only model is deprecated, so the whole provider disappears.
        assert len(result) == 1
        assert result[0].provider == "openai"
        assert len(result[0].models) == 1
        assert result[0].models[0].model == "gpt-4o"

    # Rows: (current credentials, model schema, expected number of parameter rules).
    @pytest.mark.parametrize(
        ("credentials", "schema", "expected_count"),
        [
            (None, None, 0),
            ({"api_key": "x"}, None, 0),
            (
                {"api_key": "x"},
                SimpleNamespace(
                    parameter_rules=[
                        ParameterRule(
                            name="temperature",
                            label=I18nObject(en_US="Temperature"),
                            type=ParameterType.FLOAT,
                        )
                    ]
                ),
                1,
            ),
        ],
    )
    def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
        self,
        credentials: dict[str, Any] | None,
        schema: Any,
        expected_count: int,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Without credentials the schema is never fetched and no rules are returned."""
        service = ModelProviderService()
        provider_configuration = MagicMock()
        provider_configuration.get_current_credentials.return_value = credentials
        provider_configuration.get_model_schema.return_value = schema
        monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
        result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
        assert len(result) == expected_count
        provider_configuration.get_current_credentials.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
        )
        if credentials:
            provider_configuration.get_model_schema.assert_called_once_with(
                model_type=ModelType.LLM,
                model="gpt-4o",
                credentials=credentials,
            )
        else:
            # No credentials -> schema lookup must be skipped entirely.
            provider_configuration.get_model_schema.assert_not_called()

    def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model(self) -> None:
        """A default model from the manager is wrapped into a response object."""
        service, manager = _create_service_with_mocked_manager()
        manager.get_default_model.return_value = SimpleNamespace(
            model="gpt-4o",
            model_type=ModelType.LLM,
            provider=SimpleNamespace(
                provider="openai",
                label=I18nObject(en_US="OpenAI"),
                icon_small=None,
                supported_model_types=[ModelType.LLM],
            ),
        )
        result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
        assert result is not None
        assert result.model == "gpt-4o"
        assert result.provider.provider == "openai"
        manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)

    def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none(self) -> None:
        """No configured default model yields None."""
        service, manager = _create_service_with_mocked_manager()
        manager.get_default_model.return_value = None
        result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
        assert result is None

    def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception(self) -> None:
        """A manager failure is swallowed and reported as "no default model"."""
        service, manager = _create_service_with_mocked_manager()
        manager.get_default_model.side_effect = RuntimeError("boom")
        result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
        assert result is None

    def test_update_default_model_of_model_type_should_delegate_to_provider_manager(self) -> None:
        """Updating the default converts the model_type string and delegates to the manager."""
        service, manager = _create_service_with_mocked_manager()
        service.update_default_model_of_model_type(
            tenant_id="tenant-1",
            model_type=ModelType.LLM.value,
            provider="openai",
            model="gpt-4o",
        )
        manager.update_default_model_record.assert_called_once_with(
            tenant_id="tenant-1",
            model_type=ModelType.LLM,
            provider="openai",
            model="gpt-4o",
        )

    def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Icon bytes and mimetype come from the plugin model-provider factory."""
        service = ModelProviderService()
        factory_instance = MagicMock()
        factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
        factory_constructor = MagicMock(return_value=factory_instance)
        # Patch the factory at module scope so the service builds our mock.
        monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
        result = service.get_model_provider_icon(
            tenant_id="tenant-1",
            provider="openai",
            icon_type="icon_small",
            lang="en_US",
        )
        factory_constructor.assert_called_once_with(tenant_id="tenant-1")
        factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
        assert result == (b"icon-bytes", "image/png")

    def test_switch_preferred_provider_should_convert_enum_and_delegate(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The preferred-provider string is converted to ProviderType before delegation."""
        service = ModelProviderService()
        provider_configuration = MagicMock()
        monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
        service.switch_preferred_provider(
            tenant_id="tenant-1",
            provider="openai",
            preferred_provider_type=ProviderType.SYSTEM.value,
        )
        provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)

    @pytest.mark.parametrize(
        ("method_name", "provider_method_name"),
        [
            ("enable_model", "enable_model"),
            ("disable_model", "disable_model"),
        ],
    )
    def test_model_enablement_methods_should_convert_model_type_and_delegate(
        self,
        method_name: str,
        provider_method_name: str,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """enable_model/disable_model convert model_type and forward to the configuration."""
        service = ModelProviderService()
        provider_configuration = MagicMock()
        monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
        getattr(service, method_name)(
            tenant_id="tenant-1",
            provider="openai",
            model="gpt-4o",
            model_type=ModelType.LLM.value,
        )
        getattr(provider_configuration, provider_method_name).assert_called_once_with(
            model="gpt-4o",
            model_type=ModelType.LLM,
        )

View File

@ -85,644 +85,3 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
assert len(custom_models) == 1
# The sanitizer should drop credentials in list response
assert custom_models[0].credentials is None
# === Merged from test_model_provider_service.py ===
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
    """Build a ModelProviderService whose provider manager is replaced by a MagicMock."""
    # NOTE(review): merged duplicate of the helper defined earlier in this
    # module — redefines the same name (ruff F811) and shadows the first copy.
    manager = MagicMock()
    service = ModelProviderService()
    service._get_provider_manager = MagicMock(return_value=manager)
    return service, manager
def _build_provider_configuration(
    *,
    provider_name: str = "openai",
    supported_model_types: list[ModelType] | None = None,
    custom_models: list[Any] | None = None,
    custom_config_available: bool = True,
) -> SimpleNamespace:
    """Assemble a SimpleNamespace mimicking a provider configuration for tests."""
    # NOTE(review): merged duplicate of the builder defined earlier in this
    # module — redefines the same name (ruff F811) and shadows the first copy.
    if supported_model_types is None:
        supported_model_types = [ModelType.LLM]
    return SimpleNamespace(
        provider=SimpleNamespace(
            provider=provider_name,
            label=I18nObject(en_US=provider_name),
            description=None,
            icon_small=None,
            icon_small_dark=None,
            background=None,
            help=None,
            supported_model_types=supported_model_types,
            configurate_methods=[],
            provider_credential_schema=None,
            model_credential_schema=None,
        ),
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=SimpleNamespace(
            provider=SimpleNamespace(
                current_credential_id="cred-1",
                current_credential_name="Credential 1",
                available_credentials=[],
            ),
            models=custom_models,
            can_added_models=[],
        ),
        system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
        is_custom_configuration_available=lambda: custom_config_available,
    )
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
    """The configuration stored under the provider key is returned unchanged."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    provider_configuration = SimpleNamespace(name="provider-config")
    manager.get_configurations.return_value = {"openai": provider_configuration}
    # Act
    result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
    # Assert
    assert result is provider_configuration
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
    """A provider absent from the manager's map raises ProviderNotFoundError."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_configurations.return_value = {}
    # Act / Assert
    with pytest.raises(ProviderNotFoundError, match="does not exist"):
        service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
    """Providers not supporting the requested model type are filtered out of the list."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    # Supports LLM, so it survives the model_type filter below.
    allowed = _build_provider_configuration(
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
        custom_config_available=False,
    )
    # Only supports text embedding, so it must be dropped.
    filtered = _build_provider_configuration(
        provider_name="embedding",
        supported_model_types=[ModelType.TEXT_EMBEDDING],
        custom_config_available=True,
    )
    manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
    # Act
    result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert len(result) == 1
    assert result[0].provider == "openai"
    # custom_config_available=False surfaces as the "no-configure" status.
    assert result[0].custom_configuration.status.value == "no-configure"
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
    """Raw model entities returned by the manager are wrapped into response objects."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()

    class _Model:
        # Minimal stand-in exposing only the model_dump() payload the service consumes.
        def __init__(self, model_name: str) -> None:
            self.model_name = model_name

        def model_dump(self) -> dict[str, Any]:
            return {
                "model": self.model_name,
                "label": {"en_US": self.model_name},
                "model_type": ModelType.LLM,
                "features": [],
                "fetch_from": FetchFrom.PREDEFINED_MODEL,
                "model_properties": {},
                "deprecated": False,
                "status": ModelStatus.ACTIVE,
                "load_balancing_enabled": False,
                "has_invalid_load_balancing_configs": False,
                "provider": {
                    "provider": "openai",
                    "label": {"en_US": "OpenAI"},
                    "icon_small": None,
                    "icon_small_dark": None,
                    "supported_model_types": [ModelType.LLM],
                },
            }

    provider_configurations = SimpleNamespace(
        get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
    )
    manager.get_configurations.return_value = provider_configurations
    # Act
    result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
    # Assert
    assert len(result) == 2
    assert result[0].model == "gpt-4o"
    assert result[1].provider.provider == "openai"
    provider_configurations.get_models.assert_called_once_with(provider="openai")
# Rows: (service method, kwargs for it, expected configuration method,
# expected args — a tuple means positional delegation, a dict means keyword
# delegation — and the value the configuration method returns).
@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
    [
        (
            "get_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "get_provider_credential",
            {"credential_id": "cred-1"},
            {"token": "abc"},
        ),
        (
            "validate_provider_credentials",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
            "validate_provider_credentials",
            ({"token": "abc"},),
            None,
        ),
        (
            "create_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
            "create_provider_credential",
            ({"token": "abc"}, "A"),
            None,
        ),
        (
            "update_provider_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "credentials": {"token": "abc"},
                "credential_id": "cred-1",
                "credential_name": "B",
            },
            "update_provider_credential",
            {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
            None,
        ),
        (
            "remove_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "delete_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "switch_active_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
    ],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    provider_call_kwargs: Any,
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider-credential methods forward to the matching configuration method."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
    # Act
    result = getattr(service, method_name)(**method_kwargs)
    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    provider_method = getattr(provider_configuration, provider_method_name)
    # Tuple rows delegate positionally, dict rows by keyword, anything else as one arg.
    if isinstance(provider_call_kwargs, tuple):
        provider_method.assert_called_once_with(*provider_call_kwargs)
    elif isinstance(provider_call_kwargs, dict):
        provider_method.assert_called_once_with(**provider_call_kwargs)
    else:
        provider_method.assert_called_once_with(provider_call_kwargs)
    # Only the getter surfaces the configuration method's return value.
    if method_name == "get_provider_credential":
        assert result == {"token": "abc"}
# Rows: (service method, kwargs for it, expected configuration method,
# expected keyword args for that method, and its return value). Every row
# exercises the string -> ModelType conversion performed by the service.
@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
    [
        (
            "get_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "get_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            {"api_key": "x"},
        ),
        (
            "validate_model_credentials",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
            },
            "validate_custom_model_credentials",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
            None,
        ),
        (
            "create_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            "create_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            None,
        ),
        (
            "update_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            "update_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            None,
        ),
        (
            "remove_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "delete_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_custom_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "switch_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "add_model_credential_to_model_list",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "add_model_credential_to_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "remove_model",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
            },
            "delete_custom_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o"},
            None,
        ),
    ],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    expected_kwargs: dict[str, Any],
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Custom-model methods convert the model_type string to ModelType and delegate."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
    # Act
    result = getattr(service, method_name)(**method_kwargs)
    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
    # Only the getter surfaces the configuration method's return value.
    if method_name == "get_model_credential":
        assert result == {"api_key": "x"}
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
    """Deprecated models are dropped; surviving models are grouped by provider."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    openai_provider = SimpleNamespace(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        icon_small=None,
        icon_small_dark=None,
    )
    anthropic_provider = SimpleNamespace(
        provider="anthropic",
        label=I18nObject(en_US="Anthropic"),
        icon_small=None,
        icon_small_dark=None,
    )
    # One active model plus two deprecated ones (across two providers).
    models = [
        SimpleNamespace(
            provider=openai_provider,
            model="gpt-4o",
            label=I18nObject(en_US="GPT-4o"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=False,
        ),
        SimpleNamespace(
            provider=openai_provider,
            model="old-openai",
            label=I18nObject(en_US="Old OpenAI"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
        SimpleNamespace(
            provider=anthropic_provider,
            model="old-anthropic",
            label=I18nObject(en_US="Old Anthropic"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
    ]
    provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
    manager.get_configurations.return_value = provider_configurations
    # Act
    result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
    # Anthropic's only model is deprecated, so the whole provider disappears.
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert len(result[0].models) == 1
    assert result[0].models[0].model == "gpt-4o"
@pytest.mark.parametrize(
    ("credentials", "schema", "expected_count"),
    [
        # No credentials -> no schema lookup, zero rules.
        (None, None, 0),
        # Credentials present but schema lookup yields nothing -> zero rules.
        ({"api_key": "x"}, None, 0),
        # Credentials and a schema with one rule -> that rule is returned.
        (
            {"api_key": "x"},
            SimpleNamespace(
                parameter_rules=[
                    ParameterRule(
                        name="temperature",
                        label=I18nObject(en_US="Temperature"),
                        type=ParameterType.FLOAT,
                    )
                ]
            ),
            1,
        ),
    ],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
    credentials: dict[str, Any] | None,
    schema: Any,
    expected_count: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Parameter rules are only fetched when credentials exist, else an empty list is returned."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    provider_configuration.get_current_credentials.return_value = credentials
    provider_configuration.get_model_schema.return_value = schema
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act
    result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
    # Assert
    assert len(result) == expected_count
    provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
    if credentials:
        provider_configuration.get_model_schema.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
            credentials=credentials,
        )
    else:
        # Without credentials the schema must never be requested.
        provider_configuration.get_model_schema.assert_not_called()
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
    """A default model returned by the provider manager is mapped into the response entity."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = SimpleNamespace(
        model="gpt-4o",
        model_type=ModelType.LLM,
        provider=SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            supported_model_types=[ModelType.LLM],
        ),
    )
    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert result is not None
    assert result.model == "gpt-4o"
    assert result.provider.provider == "openai"
    # The string model_type must be converted to the enum before delegation.
    manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
    """When the provider manager has no default model, the service yields None."""
    # Arrange: manager reports that no default model is configured.
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = None
    # Act
    default_model = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert default_model is None
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
    """A manager failure is swallowed and surfaces to the caller as None."""
    # Arrange: make the manager blow up on lookup.
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.side_effect = RuntimeError("boom")
    # Act
    default_model = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert default_model is None
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
    """Updating the default model delegates to the manager with the enum-converted model type."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    # Act
    service.update_default_model_of_model_type(
        tenant_id="tenant-1",
        model_type=ModelType.LLM.value,
        provider="openai",
        model="gpt-4o",
    )
    # Assert
    manager.update_default_model_record.assert_called_once_with(
        tenant_id="tenant-1",
        model_type=ModelType.LLM,
        provider="openai",
        model="gpt-4o",
    )
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
    """Icon retrieval builds a plugin factory for the tenant and forwards the lookup to it."""
    # Arrange
    service = ModelProviderService()
    factory_instance = MagicMock()
    factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
    factory_constructor = MagicMock(return_value=factory_instance)
    # Intercept factory construction so no real plugin backend is touched.
    monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
    # Act
    result = service.get_model_provider_icon(
        tenant_id="tenant-1",
        provider="openai",
        icon_type="icon_small",
        lang="en_US",
    )
    # Assert
    factory_constructor.assert_called_once_with(tenant_id="tenant-1")
    factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
    assert result == (b"icon-bytes", "image/png")
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
    """The preferred-provider string is converted to a ProviderType enum before delegation."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act
    service.switch_preferred_provider(
        tenant_id="tenant-1",
        provider="openai",
        preferred_provider_type=ProviderType.SYSTEM.value,
    )
    # Assert
    provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
@pytest.mark.parametrize(
    ("method_name", "provider_method_name"),
    [
        ("enable_model", "enable_model"),
        ("disable_model", "disable_model"),
    ],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
    method_name: str,
    provider_method_name: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """enable_model/disable_model convert the model_type string and delegate to the configuration."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act: dispatch dynamically so both service methods share one test body.
    getattr(service, method_name)(
        tenant_id="tenant-1",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM.value,
    )
    # Assert
    getattr(provider_configuration, provider_method_name).assert_called_once_with(
        model="gpt-4o",
        model_type=ModelType.LLM,
    )

View File

@ -12,7 +12,6 @@ This test suite covers all functionality of the current VariableTruncator includ
import functools
import json
import uuid
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
@ -674,229 +673,3 @@ def test_dummy_variable_truncator_methods():
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
# === Merged from test_variable_truncator_additional.py ===
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
    """Concrete subclass that forwards to BaseTruncator's abstract method bodies.

    Used to execute the abstract placeholders directly for coverage purposes.
    """

    def truncate(self, segment: Any) -> TruncationResult:
        # Arrange / Act
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        # Arrange / Act
        return super().truncate_variable_mapping(v)  # type: ignore[misc]
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
    """Calling the abstract base methods via super() executes their placeholder bodies (return None)."""
    # Arrange
    passthrough = _AbstractPassthrough()
    # Act
    truncate_result = passthrough.truncate(StringSegment(value="x"))
    mapping_result = passthrough.truncate_variable_mapping({"a": 1})
    # Assert
    assert truncate_result is None
    assert mapping_result is None
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """VariableTruncator.default() reads its three limits from dify_config."""
    # Arrange: override the config values the factory is expected to consume.
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
    # Act
    truncator = VariableTruncator.default()
    # Assert
    assert truncator._max_size_bytes == 111
    assert truncator._array_element_limit == 7
    assert truncator._string_length_limit == 33
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
    """Keys whose values cannot fit the byte budget are replaced by the literal "..."."""
    # Arrange: a 5-byte budget that the single entry cannot possibly satisfy.
    tiny_truncator = VariableTruncator(max_size_bytes=5)
    # Act
    truncated_mapping, was_truncated = tiny_truncator.truncate_variable_mapping({"very_long_key": "value"})
    # Assert
    assert was_truncated is True
    assert truncated_mapping == {"very_long_key": "..."}
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
    """Segment values inside a mapping pass through untouched when the budget is large enough."""
    # Arrange
    truncator = VariableTruncator(max_size_bytes=100)
    mapping = {"seg": StringSegment(value="hello")}
    # Act
    result, truncated = truncator.truncate_variable_mapping(mapping)
    # Assert
    assert isinstance(result["seg"], StringSegment)
    assert result["seg"].value == "hello"
    assert truncated is False
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # Scalars never need truncation...
        (None, False),
        (True, False),
        (1, False),
        (1.5, False),
        # ...strings and containers may.
        ("x", True),
        ({"k": "v"}, True),
    ],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
    """_json_value_needs_truncation flags only strings and containers as truncatable."""
    # Arrange
    # Act
    result = VariableTruncator._json_value_needs_truncation(value)
    # Assert
    assert result is expected
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the internally-truncated segment still exceeds the byte limit, a string fallback is applied."""
    # Arrange: force _truncate_segment to report an over-budget result.
    truncator = VariableTruncator(max_size_bytes=10)
    forced_result = truncator_module._PartResult(
        value=StringSegment(value="this is too long"),
        value_size=100,
        truncated=True,
    )
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
    # Act
    result = truncator.truncate(StringSegment(value="input"))
    # Assert
    assert result.truncated is True
    assert isinstance(result.result, StringSegment)
    # The fallback stores the raw string, not a JSON-quoted rendering of it.
    assert not result.result.value.startswith('"')
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """_truncate_segment asserts when a segment claims to need truncation but has no handler."""
    # Arrange: make every segment (even an integer) claim it needs truncation.
    truncator = VariableTruncator()
    monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
    # Act / Assert
    with pytest.raises(AssertionError):
        truncator._truncate_segment(IntegerSegment(value=1), 10)
def test_calculate_json_size_should_unwrap_segment_values() -> None:
    """A segment's JSON size equals the JSON size of its wrapped value."""
    # Arrange / Act
    wrapped_size = VariableTruncator.calculate_json_size(StringSegment(value="abc"))
    plain_size = VariableTruncator.calculate_json_size("abc")
    # Assert
    assert wrapped_size == plain_size
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
    """calculate_json_size accepts UpdatedVariable inputs and reports a positive size."""
    # Arrange
    updated_var = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
    # Act / Assert
    assert VariableTruncator.calculate_json_size(updated_var) > 0
def test_maybe_qa_structure_should_validate_shape() -> None:
    """_maybe_qa_structure requires a "qa_chunks" key holding a list."""
    # Act / Assert: list-valued key passes, non-list or missing key fails.
    assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
    assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
    assert VariableTruncator._maybe_qa_structure({}) is False
def test_maybe_parent_child_structure_should_validate_shape() -> None:
    """_maybe_parent_child_structure requires a string parent_mode and a list of chunks."""
    valid = {"parent_mode": "full", "parent_child_chunks": []}
    bad_mode = {"parent_mode": 1, "parent_child_chunks": []}
    bad_chunks = {"parent_mode": "full", "parent_child_chunks": "bad"}
    # Act / Assert
    assert VariableTruncator._maybe_parent_child_structure(valid) is True
    assert VariableTruncator._maybe_parent_child_structure(bad_mode) is False
    assert VariableTruncator._maybe_parent_child_structure(bad_chunks) is False
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
    """Segment values nested in an object are themselves truncated and stay segments."""
    # Arrange: limits small enough that "long-content" must be shortened.
    truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
    mapping = {"s": StringSegment(value="long-content")}
    # Act
    result = truncator._truncate_object(mapping, 20)
    # Assert
    assert result.truncated is True
    assert isinstance(result.value["s"], StringSegment)
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
    """An UpdatedVariable input is converted to a plain dict by _truncate_json_primitives."""
    # Arrange
    truncator = VariableTruncator(max_size_bytes=100)
    updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
    # Act
    result = truncator._truncate_json_primitives(updated, 100)
    # Assert
    assert isinstance(result.value, dict)
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
    """A value outside the supported JSON primitive set triggers an AssertionError."""
    truncator = VariableTruncator()
    # Act / Assert: a bare object() is not a JSON primitive.
    with pytest.raises(AssertionError):
        truncator._truncate_json_primitives(object(), 100)  # type: ignore[arg-type]
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An over-budget non-string segment is serialized to a JSON string segment as a last resort."""
    # Arrange: force the internal truncation step to return an oversized ObjectSegment.
    truncator = VariableTruncator(max_size_bytes=10)
    forced_segment = ObjectSegment(value={"k": "v"})
    forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
    # Act
    result = truncator.truncate(ObjectSegment(value={"a": "b"}))
    # Assert: the fallback converts the object into a StringSegment.
    assert result.truncated is True
    assert isinstance(result.result, StringSegment)

View File

@ -0,0 +1,174 @@
from collections.abc import Mapping
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
    """Concrete subclass forwarding to BaseTruncator's abstract placeholder bodies."""

    def truncate(self, segment: Any) -> TruncationResult:
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        return super().truncate_variable_mapping(v)  # type: ignore[misc]
class TestBaseTruncatorContract:
    """Exercise the BaseTruncator abstract-method placeholders."""

    def test_base_truncator_methods_should_execute_abstract_placeholders(self) -> None:
        """Calling the abstract bodies via super() returns None for both methods."""
        passthrough = _AbstractPassthrough()
        truncate_result = passthrough.truncate(StringSegment(value="x"))
        mapping_result = passthrough.truncate_variable_mapping({"a": 1})
        assert truncate_result is None
        assert mapping_result is None
class TestVariableTruncatorAdditionalBehavior:
    """Additional VariableTruncator coverage: config defaults, mapping handling, fallbacks."""

    def test_default_should_use_dify_config_limits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """VariableTruncator.default() reads its three limits from dify_config."""
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
        truncator = VariableTruncator.default()
        assert truncator._max_size_bytes == 111
        assert truncator._array_element_limit == 7
        assert truncator._string_length_limit == 33

    def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis(self) -> None:
        """Keys whose values cannot fit the byte budget are replaced by "..."."""
        truncator = VariableTruncator(max_size_bytes=5)
        mapping = {"very_long_key": "value"}
        result, truncated = truncator.truncate_variable_mapping(mapping)
        assert result == {"very_long_key": "..."}
        assert truncated is True

    def test_truncate_variable_mapping_should_handle_segment_values(self) -> None:
        """Segment values pass through unchanged when they fit the budget."""
        truncator = VariableTruncator(max_size_bytes=100)
        mapping = {"seg": StringSegment(value="hello")}
        result, truncated = truncator.truncate_variable_mapping(mapping)
        assert isinstance(result["seg"], StringSegment)
        assert result["seg"].value == "hello"
        assert truncated is False

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            (None, False),
            (True, False),
            (1, False),
            (1.5, False),
            ("x", True),
            ({"k": "v"}, True),
        ],
    )
    def test_json_value_needs_truncation_should_match_expected_rules(
        self,
        value: Any,
        expected: bool,
    ) -> None:
        """Only strings and containers are flagged as potentially truncatable."""
        result = VariableTruncator._json_value_needs_truncation(value)
        assert result is expected

    def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """If the internally-truncated segment is still over budget, a string fallback kicks in."""
        truncator = VariableTruncator(max_size_bytes=10)
        # Force _truncate_segment to report an oversized result.
        forced_result = truncator_module._PartResult(
            value=StringSegment(value="this is too long"),
            value_size=100,
            truncated=True,
        )
        monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
        result = truncator.truncate(StringSegment(value="input"))
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        # The fallback keeps a raw string, not a JSON-quoted rendering.
        assert not result.result.value.startswith('"')

    def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A segment that claims to need truncation but has no handler asserts."""
        truncator = VariableTruncator()
        monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
        with pytest.raises(AssertionError):
            truncator._truncate_segment(IntegerSegment(value=1), 10)

    def test_calculate_json_size_should_unwrap_segment_values(self) -> None:
        """A segment's JSON size equals the JSON size of its wrapped value."""
        segment = StringSegment(value="abc")
        size = VariableTruncator.calculate_json_size(segment)
        assert size == VariableTruncator.calculate_json_size("abc")

    def test_calculate_json_size_should_handle_updated_variable_instances(self) -> None:
        """UpdatedVariable inputs are measurable and yield a positive size."""
        updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
        size = VariableTruncator.calculate_json_size(updated)
        assert size > 0

    def test_maybe_qa_structure_should_validate_shape(self) -> None:
        """_maybe_qa_structure requires a "qa_chunks" key holding a list."""
        assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
        assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
        assert VariableTruncator._maybe_qa_structure({}) is False

    def test_maybe_parent_child_structure_should_validate_shape(self) -> None:
        """_maybe_parent_child_structure requires a string parent_mode and list chunks."""
        assert (
            VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
        )
        assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
        assert (
            VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"})
            is False
        )

    def test_truncate_object_should_truncate_segment_values_inside_object(self) -> None:
        """Segment values nested in an object are truncated and remain segments."""
        truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
        mapping = {"s": StringSegment(value="long-content")}
        result = truncator._truncate_object(mapping, 20)
        assert result.truncated is True
        assert isinstance(result.value["s"], StringSegment)

    def test_truncate_json_primitives_should_handle_updated_variable_input(self) -> None:
        """UpdatedVariable inputs are converted to plain dicts."""
        truncator = VariableTruncator(max_size_bytes=100)
        updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
        result = truncator._truncate_json_primitives(updated, 100)
        assert isinstance(result.value, dict)

    def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type(self) -> None:
        """Non-JSON-primitive values trigger an AssertionError."""
        truncator = VariableTruncator()
        with pytest.raises(AssertionError):
            truncator._truncate_json_primitives(object(), 100)  # type: ignore[arg-type]

    def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An over-budget non-string segment is serialized to a JSON string segment."""
        truncator = VariableTruncator(max_size_bytes=10)
        forced_segment = ObjectSegment(value={"k": "v"})
        forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
        monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
        result = truncator.truncate(ObjectSegment(value={"a": "b"}))
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)

View File

@ -559,771 +559,3 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
    """Provide a bare Flask app for building test request contexts."""
    return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
    """Replace the service module's db engine, Session and sessionmaker with fakes yielding *session*."""
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
    monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
    monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
    """Build a duck-typed WorkflowWebhookTrigger from arbitrary attributes (cast, not a real model)."""
    return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
    """Build a duck-typed Workflow from arbitrary attributes (cast, not a real model)."""
    return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
    """Build a duck-typed App from arbitrary attributes (cast, not a real model)."""
    return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """A missing webhook row raises ValueError("Webhook not found")."""
    # Arrange: the first session.scalar lookup (the webhook) returns nothing.
    fake_session = MagicMock()
    fake_session.scalar.return_value = None
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="Webhook not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A webhook without a matching app trigger raises ValueError."""
    # Arrange: scalar() returns the webhook first, then None for the app trigger.
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, None]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="App trigger not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A rate-limited app trigger is rejected with a descriptive ValueError."""
    # Arrange: scalar() yields the webhook, then a RATE_LIMITED app trigger.
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="rate limited"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A disabled app trigger is rejected with a descriptive ValueError."""
    # Arrange: scalar() yields the webhook, then a DISABLED app trigger.
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="disabled"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """An enabled trigger whose workflow row is missing raises ValueError."""
    # Arrange: scalar() yields webhook, enabled app trigger, then None for the workflow.
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="Workflow not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """In normal mode the service returns the trigger, workflow and the node's config mapping."""
    # Arrange: scalar() yields webhook, enabled app trigger, then the workflow.
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
    _patch_session(monkeypatch, fake_session)
    # Act
    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    # Assert
    assert got_trigger is webhook_trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """Debug mode skips the app-trigger status lookup: only webhook and workflow are fetched."""
    # Arrange: scalar() yields webhook then workflow directly (no app-trigger query).
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
    fake_session = MagicMock()
    fake_session.scalar.side_effect = [webhook_trigger, workflow]
    _patch_session(monkeypatch, fake_session)
    # Act
    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
        "webhook-1", is_debug=True
    )
    # Assert
    assert got_trigger is webhook_trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unrecognized Content-Type falls back to raw-text body extraction and logs a warning."""
    # Arrange
    warning_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warning_mock)
    webhook_trigger = MagicMock()
    # Act: post with a vendor content type the service has no handler for.
    with flask_app.test_request_context(
        "/webhook",
        method="POST",
        headers={"Content-Type": "application/vnd.custom"},
        data="plain content",
    ):
        result = WebhookService.extract_webhook_data(webhook_trigger)
    # Assert
    assert result["body"] == {"raw": "plain content"}
    warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A body exceeding WEBHOOK_REQUEST_BODY_MAX_SIZE raises RequestEntityTooLarge."""
    # Arrange: shrink the configured limit below the 2-byte payload.
    monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
    # Act / Assert
    with flask_app.test_request_context("/webhook", method="POST", data="ab"):
        with pytest.raises(RequestEntityTooLarge):
            WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
    """An empty octet-stream payload yields {"raw": None} and no files."""
    # Arrange
    webhook_trigger = MagicMock()
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data=b""):
        body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
    # Assert
    assert body == {"raw": None}
    assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failure while building the binary file degrades to {"raw": None} instead of raising."""
    # Arrange: mimetype detection succeeds but file creation blows up.
    webhook_trigger = MagicMock()
    monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
        body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
    # Assert
    assert body == {"raw": None}
    assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If reading the request body raises, the text extractor degrades to an empty string."""
    # Arrange: make Flask's request-body read fail.
    monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data="abc"):
        body, files = WebhookService._extract_text_body()
    # Assert
    assert body == {"raw": ""}
    assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """When libmagic detection fails, the generic application/octet-stream type is used."""
    # Arrange: replace the magic module with one whose from_buffer raises.
    fake_magic = MagicMock()
    fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
    monkeypatch.setattr(service_module, "magic", fake_magic)
    # Act
    result = WebhookService._detect_binary_mimetype(b"binary")
    # Assert
    assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Uploads with no content type and an unguessable extension are still stored as files."""
    # Arrange
    webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    file_obj = MagicMock()
    file_obj.to_dict.return_value = {"id": "f-1"}
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
    # Force mimetypes to fail so the octet-stream fallback path is taken.
    monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
    uploaded = MagicMock()
    uploaded.filename = "file.unknown"
    uploaded.content_type = None
    uploaded.read.return_value = b"content"
    # Act
    result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
    # Assert
    assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Binary payloads are persisted via ToolFileManager and wrapped by file_factory."""
    # Arrange
    webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    manager = MagicMock()
    manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
    monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
    expected_file = MagicMock()
    monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
    # Act
    result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
    # Assert
    assert result is expected_file
    manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
    ("raw_value", "param_type", "expected"),
    [
        # Numeric strings convert to int or float as appropriate.
        ("42", SegmentType.NUMBER, 42),
        ("3.14", SegmentType.NUMBER, 3.14),
        # Truthy/falsy words convert to booleans.
        ("yes", SegmentType.BOOLEAN, True),
        ("no", SegmentType.BOOLEAN, False),
    ],
)
def test_convert_form_value_should_convert_supported_types(
    raw_value: str,
    param_type: str,
    expected: Any,
) -> None:
    """_convert_form_value coerces raw form strings into the declared segment type."""
    # Arrange
    # Act
    result = WebhookService._convert_form_value("param", raw_value, param_type)
    # Assert
    assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
    """File-typed form parameters cannot be coerced and raise ValueError."""
    # Act / Assert
    with pytest.raises(ValueError, match="Unsupported type"):
        WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unmapped segment type passes the value through unchanged, logging a warning."""
    # Arrange
    warning_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warning_mock)
    # Act
    result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
    # Assert
    assert result == {"x": 1}
    warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
    """Conversion failures surface as a ValueError mentioning "validation failed"."""
    # Act / Assert: "bad" is not a number, so form-data conversion must fail.
    with pytest.raises(ValueError, match="validation failed"):
        WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
    """Omitting a required configured parameter raises ValueError."""
    # Arrange: only an unrelated optional key is supplied.
    supplied = {"optional": "x"}
    parameter_config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
    # Act / Assert
    with pytest.raises(ValueError, match="Required parameter missing"):
        WebhookService._process_parameters(supplied, parameter_config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters() -> None:
    """Configured parameters are converted; unconfigured ones are kept verbatim."""
    # Arrange: "known" is declared as a number, "unknown" has no config entry.
    supplied = {"known": "1", "unknown": "x"}
    parameter_config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
    # Act
    processed = WebhookService._process_parameters(supplied, parameter_config, is_form_data=True)
    # Assert: "known" became an int, "unknown" passed through untouched.
    assert processed == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
    """An empty raw text body fails validation when the raw parameter is required."""
    # Arrange
    # Act / Assert
    with pytest.raises(ValueError, match="Required body content missing"):
        WebhookService._process_body_parameters(
            raw_body={"raw": ""},
            body_configs=[WebhookBodyParameter(name="raw", required=True)],
            content_type=ContentType.TEXT,
        )
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
    """For multipart form data, file-typed configs are skipped (files are handled separately)."""
    # Arrange: the required FILE config must not cause a missing-parameter error.
    raw_body = {"message": "hello", "extra": "x"}
    body_configs = [
        WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
        WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
    ]
    # Act
    result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
    # Assert
    assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
    """Underscored header names satisfy configs declared with dashes (no exception)."""
    configs = [WebhookParameter(name="x-api-key", required=True)]
    # Must not raise: "x_api_key" is the sanitized form of "x-api-key".
    WebhookService._validate_required_headers({"x_api_key": "123"}, configs)
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
    """A missing required header raises 'Required header missing'."""
    configs = [WebhookParameter(name="x-api-key", required=True)]
    with pytest.raises(ValueError, match="Required header missing"):
        WebhookService._validate_required_headers({"x-other": "123"}, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
    """A JSON request against a TEXT-configured node is rejected with a mismatch error."""
    incoming = {"method": "POST", "headers": {"Content-Type": "application/json"}}
    node_data = WebhookData(method="post", content_type=ContentType.TEXT)
    outcome = WebhookService._validate_http_metadata(incoming, node_data)
    assert outcome["valid"] is False
    assert "Content-type mismatch" in outcome["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
    """Lowercase 'content-type' keys are honored and media-type parameters are stripped."""
    extracted = WebhookService._extract_content_type({"content-type": "application/json; charset=utf-8"})
    assert extracted == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
    """Workflow inputs expose the full payload plus headers/query/body projections."""
    payload = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
    inputs = WebhookService.build_workflow_inputs(payload)
    assert inputs["webhook_data"] == payload
    assert inputs["webhook_headers"] == {"h": "v"}
    assert inputs["webhook_query_params"] == {"q": 1}
    assert inputs["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    """Happy path: quota is consumed and the workflow is handed to the async runner."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    webhook_data = {"body": {"x": 1}}
    session = MagicMock()
    _patch_session(monkeypatch, session)
    # The service resolves an end user before dispatching; return a canned one.
    end_user = SimpleNamespace(id="end-user-1")
    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
    )
    # Quota consumption succeeds (no side effect configured on the mock).
    quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    trigger_async_mock = MagicMock()
    monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
    # Act
    WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
    # Assert: the async workflow dispatch was made exactly once.
    trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Quota exhaustion marks the tenant's triggers rate limited and re-raises the error."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    session = MagicMock()
    _patch_session(monkeypatch, session)
    monkeypatch.setattr(
        service_module.EndUserService,
        "get_or_create_end_user_by_type",
        MagicMock(return_value=SimpleNamespace(id="end-user-1")),
    )
    # Quota consumption raises, simulating an exhausted trigger quota.
    quota_type = SimpleNamespace(
        TRIGGER=SimpleNamespace(
            consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
        )
    )
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    mark_rate_limited_mock = MagicMock()
    monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
    # Act / Assert: the QuotaExceededError propagates to the caller...
    with pytest.raises(QuotaExceededError):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    # ...and the tenant was flagged as rate limited exactly once.
    mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unexpected failures are logged via logger.exception and re-raised unchanged."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    session = MagicMock()
    _patch_session(monkeypatch, session)
    # End-user resolution blows up with a non-quota error.
    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
    )
    logger_exception_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    # Act / Assert: the original RuntimeError propagates and is logged once.
    with pytest.raises(RuntimeError, match="boom"):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
    """More webhook nodes than MAX_WEBHOOK_NODES_PER_WORKFLOW aborts the sync."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    too_many = WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1
    workflow = _workflow(
        walk_nodes=lambda _node_type: [(f"node-{i}", {}) for i in range(too_many)]
    )
    with pytest.raises(ValueError, match="maximum webhook node limit"):
        WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unobtainable Redis lock makes the sync fail fast instead of proceeding."""
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
    failed_lock = MagicMock()
    failed_lock.acquire.return_value = False
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=failed_lock))
    with pytest.raises(RuntimeError, match="Failed to acquire lock"):
        WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Sync creates trigger rows for new webhook nodes and deletes rows for removed ones."""
    # Arrange: the workflow has one webhook node ("node-new") that has no DB record yet.
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
    # Stand-in for the ORM model so the test never touches SQLAlchemy mappings.
    class _WorkflowWebhookTrigger:
        app_id = "app_id"
        tenant_id = "tenant_id"
        webhook_id = "webhook_id"
        node_id = "node_id"
        def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
            self.id = None
            self.app_id = app_id
            self.tenant_id = tenant_id
            self.node_id = node_id
            self.webhook_id = webhook_id
            self.created_by = created_by
    # Chainable replacement for sqlalchemy.select().
    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self
    # Session double: records adds/deletes, assigns ids on flush, and reports one
    # pre-existing record ("node-stale") that is no longer present in the workflow.
    class _Session:
        def __init__(self) -> None:
            self.added: list[Any] = []
            self.deleted: list[Any] = []
            self.commit_count = 0
            self.existing_records = [SimpleNamespace(node_id="node-stale")]
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: self.existing_records)
        def add(self, obj: Any) -> None:
            self.added.append(obj)
        def flush(self) -> None:
            for idx, obj in enumerate(self.added, start=1):
                if obj.id is None:
                    obj.id = f"rec-{idx}"
        def commit(self) -> None:
            self.commit_count += 1
        def delete(self, obj: Any) -> None:
            self.deleted.append(obj)
    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.return_value = None
    fake_session = _Session()
    monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    redis_set_mock = MagicMock()
    redis_delete_mock = MagicMock()
    monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
    monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
    monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
    _patch_session(monkeypatch, fake_session)
    # Act
    WebhookService.sync_webhook_relationships(app, workflow)
    # Assert: one record created, one stale record removed, cache touched, lock released.
    assert len(fake_session.added) == 1
    assert len(fake_session.deleted) == 1
    redis_set_mock.assert_called_once()
    redis_delete_mock.assert_called_once()
    lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failure while releasing the Redis lock is logged, not propagated."""
    # Arrange: workflow with no webhook nodes, so the sync body is a no-op.
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [])
    # Chainable replacement for sqlalchemy.select().
    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self
    # Session double with no existing records and a no-op commit.
    class _Session:
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: [])
        def commit(self) -> None:
            return None
    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = RuntimeError("release failed")
    logger_exception_mock = MagicMock()
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    _patch_session(monkeypatch, _Session())
    # Act: must not raise despite the failing lock release.
    WebhookService.sync_webhook_relationships(app, workflow)
    # Assert: the release failure was logged exactly once.
    assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
    """Invalid JSON in the configured response body falls back to a default message."""
    config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
    response_body, response_status = WebhookService.generate_webhook_response(config)
    assert response_status == 200
    assert "message" in response_body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
    """Generated webhook identifiers are 24-character strings."""
    generated = WebhookService.generate_webhook_id()
    assert isinstance(generated, str)
    assert len(generated) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
    """Non-string keys are returned unchanged by _sanitize_key."""
    assert WebhookService._sanitize_key(123) == 123  # type: ignore[arg-type]

View File

@ -0,0 +1,671 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
    """Provide a bare Flask application for request-context based tests."""
    app = Flask(__name__)
    return app
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
    """Route every DB entry point of the service module to the supplied fake session."""
    fake_db = SimpleNamespace(engine=MagicMock(), session=MagicMock())
    monkeypatch.setattr(service_module, "db", fake_db)
    monkeypatch.setattr(service_module, "Session", lambda *_args, **_kwargs: _SessionContext(session))
    monkeypatch.setattr(service_module, "sessionmaker", lambda *_args, **_kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
    """Build a lightweight WorkflowWebhookTrigger stand-in from keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(WorkflowWebhookTrigger, stub)
def _workflow(**kwargs: Any) -> Workflow:
    """Build a lightweight Workflow stand-in from keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Workflow, stub)
def _app(**kwargs: Any) -> App:
    """Build a lightweight App stand-in from keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
class TestWebhookServiceLookup:
    """get_webhook_trigger_and_workflow: the lookup chain and each of its failure modes.

    The fake session's ``scalar`` side-effect list models the service's sequential DB
    lookups: webhook trigger, then app trigger (skipped in debug mode), then workflow.
    """
    def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Unknown webhook id -> 'Webhook not found'."""
        fake_session = MagicMock()
        fake_session.scalar.return_value = None
        _patch_session(monkeypatch, fake_session)
        with pytest.raises(ValueError, match="Webhook not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Webhook found but no matching app trigger -> 'App trigger not found'."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        fake_session = MagicMock()
        # Lookups: webhook trigger found, app trigger missing.
        fake_session.scalar.side_effect = [webhook_trigger, None]
        _patch_session(monkeypatch, fake_session)
        with pytest.raises(ValueError, match="App trigger not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A RATE_LIMITED app trigger rejects the request."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
        _patch_session(monkeypatch, fake_session)
        with pytest.raises(ValueError, match="rate limited"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A DISABLED app trigger rejects the request."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
        _patch_session(monkeypatch, fake_session)
        with pytest.raises(ValueError, match="disabled"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Enabled trigger but missing workflow -> 'Workflow not found'."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
        fake_session = MagicMock()
        # Lookups: webhook trigger, enabled app trigger, then no workflow.
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
        _patch_session(monkeypatch, fake_session)
        with pytest.raises(ValueError, match="Workflow not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Full chain succeeds: returns (trigger, workflow, node config)."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
        workflow = MagicMock()
        workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
        _patch_session(monkeypatch, fake_session)
        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
        assert got_trigger is webhook_trigger
        assert got_workflow is workflow
        assert got_node_config == {"data": {"key": "value"}}
    def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Debug mode skips the app-trigger status check (only two scalar lookups)."""
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        workflow = MagicMock()
        workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
        fake_session = MagicMock()
        # Lookups in debug mode: webhook trigger, then workflow (no app trigger).
        fake_session.scalar.side_effect = [webhook_trigger, workflow]
        _patch_session(monkeypatch, fake_session)
        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
            "webhook-1",
            is_debug=True,
        )
        assert got_trigger is webhook_trigger
        assert got_workflow is workflow
        assert got_node_config == {"data": {"mode": "debug"}}
class TestWebhookServiceExtractionFallbacks:
    """Request-data extraction helpers: fallback and error-tolerance behavior."""
    def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An unrecognized Content-Type falls back to raw-text extraction and warns."""
        warning_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "warning", warning_mock)
        webhook_trigger = MagicMock()
        with flask_app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "application/vnd.custom"},
            data="plain content",
        ):
            result = WebhookService.extract_webhook_data(webhook_trigger)
        assert result["body"] == {"raw": "plain content"}
        warning_mock.assert_called_once()
    def test_extract_webhook_data_should_raise_for_request_too_large(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A body exceeding WEBHOOK_REQUEST_BODY_MAX_SIZE raises RequestEntityTooLarge."""
        # Shrink the configured limit below the 2-byte payload.
        monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
        with flask_app.test_request_context("/webhook", method="POST", data="ab"):
            with pytest.raises(RequestEntityTooLarge):
                WebhookService.extract_webhook_data(MagicMock())
    def test_extract_octet_stream_body_should_return_none_when_empty_payload(self, flask_app: Flask) -> None:
        """An empty binary payload yields raw=None and no files."""
        webhook_trigger = MagicMock()
        with flask_app.test_request_context("/webhook", method="POST", data=b""):
            body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
        assert body == {"raw": None}
        assert files == {}
    def test_extract_octet_stream_body_should_return_none_when_processing_raises(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A failure while building the file object degrades to raw=None, not an error."""
        webhook_trigger = MagicMock()
        monkeypatch.setattr(
            WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")
        )
        # File creation blows up; extraction must swallow it.
        monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
        with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
            body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
        assert body == {"raw": None}
        assert files == {}
    def test_extract_text_body_should_return_empty_string_when_request_read_fails(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """If reading the request body fails, the text body defaults to an empty string."""
        monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
        with flask_app.test_request_context("/webhook", method="POST", data="abc"):
            body, files = WebhookService._extract_text_body()
        assert body == {"raw": ""}
        assert files == {}
    def test_detect_binary_mimetype_should_fallback_when_magic_raises(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """libmagic failures fall back to application/octet-stream."""
        fake_magic = MagicMock()
        fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
        monkeypatch.setattr(service_module, "magic", fake_magic)
        result = WebhookService._detect_binary_mimetype(b"binary")
        assert result == "application/octet-stream"
    def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Uploads with no detectable mimetype are still processed via the binary path."""
        webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
        file_obj = MagicMock()
        file_obj.to_dict.return_value = {"id": "f-1"}
        monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
        # Filename-based mimetype guessing yields nothing.
        monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
        uploaded = MagicMock()
        uploaded.filename = "file.unknown"
        uploaded.content_type = None
        uploaded.read.return_value = b"content"
        result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
        assert result == {"f": {"id": "f-1"}}
    def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """File creation persists via ToolFileManager, then maps through file_factory."""
        webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
        manager = MagicMock()
        manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
        monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
        expected_file = MagicMock()
        monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
        result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
        assert result is expected_file
        manager.create_file_by_raw.assert_called_once()
class TestWebhookServiceValidationAndConversion:
    """Parameter/body validation and value-conversion helpers of WebhookService."""
    @pytest.mark.parametrize(
        ("raw_value", "param_type", "expected"),
        [
            ("42", SegmentType.NUMBER, 42),
            ("3.14", SegmentType.NUMBER, 3.14),
            ("yes", SegmentType.BOOLEAN, True),
            ("no", SegmentType.BOOLEAN, False),
        ],
    )
    def test_convert_form_value_should_convert_supported_types(
        self,
        raw_value: str,
        param_type: str,
        expected: Any,
    ) -> None:
        """Form strings convert to ints, floats, and booleans per the declared type."""
        result = WebhookService._convert_form_value("param", raw_value, param_type)
        assert result == expected
    def test_convert_form_value_should_raise_for_unsupported_type(self) -> None:
        """A segment type with no form conversion raises 'Unsupported type'."""
        with pytest.raises(ValueError, match="Unsupported type"):
            WebhookService._convert_form_value("p", "x", SegmentType.FILE)
    def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Unknown segment types pass the value through unchanged, with a warning."""
        warning_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "warning", warning_mock)
        result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
        assert result == {"x": 1}
        warning_mock.assert_called_once()
    def test_validate_and_convert_value_should_wrap_conversion_errors(self) -> None:
        """Conversion failures are wrapped in a 'validation failed' ValueError."""
        with pytest.raises(ValueError, match="validation failed"):
            WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
    def test_process_parameters_should_raise_when_required_parameter_missing(self) -> None:
        """Omitting a configured required parameter raises 'Required parameter missing'."""
        raw_params = {"optional": "x"}
        config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
        with pytest.raises(ValueError, match="Required parameter missing"):
            WebhookService._process_parameters(raw_params, config, is_form_data=True)
    def test_process_parameters_should_include_unconfigured_parameters(self) -> None:
        """Configured parameters are converted; unconfigured ones pass through."""
        raw_params = {"known": "1", "unknown": "x"}
        config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
        result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
        assert result == {"known": 1, "unknown": "x"}
    def test_process_body_parameters_should_raise_when_required_text_raw_is_missing(self) -> None:
        """An empty required raw text body raises 'Required body content missing'."""
        with pytest.raises(ValueError, match="Required body content missing"):
            WebhookService._process_body_parameters(
                raw_body={"raw": ""},
                body_configs=[WebhookBodyParameter(name="raw", required=True)],
                content_type=ContentType.TEXT,
            )
    def test_process_body_parameters_should_skip_file_config_for_multipart_form_data(self) -> None:
        """FILE-typed configs are ignored for multipart bodies; other fields pass through."""
        raw_body = {"message": "hello", "extra": "x"}
        body_configs = [
            WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
            WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
        ]
        result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
        assert result == {"message": "hello", "extra": "x"}
    def test_validate_required_headers_should_accept_sanitized_header_names(self) -> None:
        """Underscored header names satisfy configs declared with dashes (no exception)."""
        headers = {"x_api_key": "123"}
        configs = [WebhookParameter(name="x-api-key", required=True)]
        WebhookService._validate_required_headers(headers, configs)
    def test_validate_required_headers_should_raise_when_required_header_missing(self) -> None:
        """A missing required header raises 'Required header missing'."""
        headers = {"x-other": "123"}
        configs = [WebhookParameter(name="x-api-key", required=True)]
        with pytest.raises(ValueError, match="Required header missing"):
            WebhookService._validate_required_headers(headers, configs)
    def test_validate_http_metadata_should_return_content_type_mismatch_error(self) -> None:
        """A JSON request against a TEXT-configured node is rejected with a mismatch error."""
        webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
        node_data = WebhookData(method="post", content_type=ContentType.TEXT)
        result = WebhookService._validate_http_metadata(webhook_data, node_data)
        assert result["valid"] is False
        assert "Content-type mismatch" in result["error"]
    def test_extract_content_type_should_fallback_to_lowercase_header_key(self) -> None:
        """Lowercase 'content-type' keys are honored and media-type parameters stripped."""
        headers = {"content-type": "application/json; charset=utf-8"}
        assert WebhookService._extract_content_type(headers) == "application/json"
    def test_build_workflow_inputs_should_include_expected_keys(self) -> None:
        """Workflow inputs expose the full payload plus headers/query/body projections."""
        webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
        result = WebhookService.build_workflow_inputs(webhook_data)
        assert result["webhook_data"] == webhook_data
        assert result["webhook_headers"] == {"h": "v"}
        assert result["webhook_query_params"] == {"q": 1}
        assert result["webhook_body"] == {"b": 2}
class TestWebhookServiceExecutionAndSync:
    """trigger_workflow_execution and sync_webhook_relationships behavior."""
    def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Happy path: quota is consumed and the workflow is handed to the async runner."""
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")
        webhook_data = {"body": {"x": 1}}
        session = MagicMock()
        _patch_session(monkeypatch, session)
        # The service resolves an end user before dispatching; return a canned one.
        end_user = SimpleNamespace(id="end-user-1")
        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(return_value=end_user),
        )
        # Quota consumption succeeds (no side effect configured on the mock).
        quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
        monkeypatch.setattr(service_module, "QuotaType", quota_type)
        trigger_async_mock = MagicMock()
        monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
        WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
        trigger_async_mock.assert_called_once()
    def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Quota exhaustion marks the tenant rate limited and re-raises the error."""
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")
        session = MagicMock()
        _patch_session(monkeypatch, session)
        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(return_value=SimpleNamespace(id="end-user-1")),
        )
        # Quota consumption raises, simulating an exhausted trigger quota.
        quota_type = SimpleNamespace(
            TRIGGER=SimpleNamespace(
                consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
            )
        )
        monkeypatch.setattr(service_module, "QuotaType", quota_type)
        mark_rate_limited_mock = MagicMock()
        monkeypatch.setattr(
            service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
        )
        with pytest.raises(QuotaExceededError):
            WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
        mark_rate_limited_mock.assert_called_once_with("tenant-1")
    def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Unexpected failures are logged via logger.exception and re-raised unchanged."""
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")
        session = MagicMock()
        _patch_session(monkeypatch, session)
        # End-user resolution blows up with a non-quota error.
        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(side_effect=RuntimeError("boom")),
        )
        logger_exception_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
        with pytest.raises(RuntimeError, match="boom"):
            WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
        logger_exception_mock.assert_called_once()
    def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
        """More webhook nodes than MAX_WEBHOOK_NODES_PER_WORKFLOW aborts the sync."""
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(
            walk_nodes=lambda _node_type: [
                (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
            ]
        )
        with pytest.raises(ValueError, match="maximum webhook node limit"):
            WebhookService.sync_webhook_relationships(app, workflow)
    def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An unobtainable Redis lock makes the sync fail fast."""
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
        lock = MagicMock()
        lock.acquire.return_value = False
        monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
        monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
        with pytest.raises(RuntimeError, match="Failed to acquire lock"):
            WebhookService.sync_webhook_relationships(app, workflow)
    def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Sync creates rows for new webhook nodes and deletes rows for removed ones."""
        # Workflow has one webhook node ("node-new") with no existing DB record.
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
        # Stand-in for the ORM model so the test never touches SQLAlchemy mappings.
        class _WorkflowWebhookTrigger:
            app_id = "app_id"
            tenant_id = "tenant_id"
            webhook_id = "webhook_id"
            node_id = "node_id"
            def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
                self.id = None
                self.app_id = app_id
                self.tenant_id = tenant_id
                self.node_id = node_id
                self.webhook_id = webhook_id
                self.created_by = created_by
        # Chainable replacement for sqlalchemy.select().
        class _Select:
            def where(self, *args: Any, **kwargs: Any) -> "_Select":
                return self
        # Session double: records adds/deletes, assigns ids on flush, and reports one
        # pre-existing record ("node-stale") no longer present in the workflow.
        class _Session:
            def __init__(self) -> None:
                self.added: list[Any] = []
                self.deleted: list[Any] = []
                self.commit_count = 0
                self.existing_records = [SimpleNamespace(node_id="node-stale")]
            def scalars(self, _stmt: Any) -> Any:
                return SimpleNamespace(all=lambda: self.existing_records)
            def add(self, obj: Any) -> None:
                self.added.append(obj)
            def flush(self) -> None:
                for idx, obj in enumerate(self.added, start=1):
                    if obj.id is None:
                        obj.id = f"rec-{idx}"
            def commit(self) -> None:
                self.commit_count += 1
            def delete(self, obj: Any) -> None:
                self.deleted.append(obj)
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.release.return_value = None
        fake_session = _Session()
        monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
        monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
        monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
        monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
        redis_set_mock = MagicMock()
        redis_delete_mock = MagicMock()
        monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
        monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
        monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
        _patch_session(monkeypatch, fake_session)
        WebhookService.sync_webhook_relationships(app, workflow)
        # One record created, one stale record removed, cache touched, lock released.
        assert len(fake_session.added) == 1
        assert len(fake_session.deleted) == 1
        redis_set_mock.assert_called_once()
        redis_delete_mock.assert_called_once()
        lock.release.assert_called_once()
    def test_sync_webhook_relationships_should_log_when_lock_release_fails(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A failure while releasing the Redis lock is logged, not propagated."""
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(walk_nodes=lambda _node_type: [])
        # Chainable replacement for sqlalchemy.select().
        class _Select:
            def where(self, *args: Any, **kwargs: Any) -> "_Select":
                return self
        # Session double with no existing records and a no-op commit.
        class _Session:
            def scalars(self, _stmt: Any) -> Any:
                return SimpleNamespace(all=lambda: [])
            def commit(self) -> None:
                return None
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.release.side_effect = RuntimeError("release failed")
        logger_exception_mock = MagicMock()
        monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
        monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
        monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
        monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
        _patch_session(monkeypatch, _Session())
        # Must not raise despite the failing lock release.
        WebhookService.sync_webhook_relationships(app, workflow)
        assert logger_exception_mock.call_count == 1
class TestWebhookServiceUtilities:
    """Small, dependency-free WebhookService helpers."""

    def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
        """Invalid JSON in the configured response body falls back to a default message."""
        config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
        response_body, response_status = WebhookService.generate_webhook_response(config)
        assert response_status == 200
        assert "message" in response_body

    def test_generate_webhook_id_should_return_24_character_identifier(self) -> None:
        """Generated webhook identifiers are 24-character strings."""
        generated = WebhookService.generate_webhook_id()
        assert isinstance(generated, str)
        assert len(generated) == 24

    def test_sanitize_key_should_return_original_value_for_non_string_input(self) -> None:
        """Non-string keys are returned unchanged by _sanitize_key."""
        assert WebhookService._sanitize_key(123) == 123  # type: ignore[arg-type]

View File

@ -0,0 +1,262 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from sqlalchemy import Engine
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
    """Patch DifyAPIRepositoryFactory and return (node_repo, workflow_run_repo, factory)."""
    execution_repo = MagicMock()
    run_repo = MagicMock()
    fake_factory = SimpleNamespace(
        create_api_workflow_node_execution_repository=MagicMock(return_value=execution_repo),
        create_api_workflow_run_repository=MagicMock(return_value=run_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", fake_factory)
    return execution_repo, run_repo, fake_factory
def _app_model(**kwargs: Any) -> App:
    """Build a lightweight attribute-bag stub typed as App."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)


def _account(**kwargs: Any) -> Account:
    """Build a lightweight attribute-bag stub typed as Account."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)


def _end_user(**kwargs: Any) -> EndUser:
    """Build a lightweight attribute-bag stub typed as EndUser."""
    stub = SimpleNamespace(**kwargs)
    return cast(EndUser, stub)
class TestWorkflowRunServiceInitialization:
    """Constructor behaviour: how WorkflowRunService resolves its session factory."""

    def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
        self,
        monkeypatch: pytest.MonkeyPatch,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """With no factory given, a sessionmaker is built from the global db engine."""
        session_factory = MagicMock(name="session_factory")
        sessionmaker_mock = MagicMock(return_value=session_factory)
        monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
        monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
        service = WorkflowRunService()
        sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
        assert service._session_factory is session_factory

    def test___init___should_create_sessionmaker_when_engine_is_provided(
        self,
        monkeypatch: pytest.MonkeyPatch,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """Passing an Engine instance wraps it in a fresh sessionmaker."""

        class FakeEngine:
            pass

        session_factory = MagicMock(name="session_factory")
        sessionmaker_mock = MagicMock(return_value=session_factory)
        # Patch Engine so the service's isinstance check matches FakeEngine.
        monkeypatch.setattr(service_module, "Engine", FakeEngine)
        monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
        engine = cast(Engine, FakeEngine())
        service = WorkflowRunService(session_factory=engine)
        sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
        assert service._session_factory is session_factory

    def test___init___should_keep_provided_sessionmaker_and_create_repositories(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """A ready-made sessionmaker is stored as-is; both repositories derive from it."""
        node_repo, workflow_run_repo, factory = repository_factory_mocks
        session_factory = MagicMock(name="session_factory")
        service = WorkflowRunService(session_factory=session_factory)
        assert service._session_factory is session_factory
        assert service._node_execution_service_repo is node_repo
        assert service._workflow_run_repo is workflow_run_repo
        factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
        factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
class TestWorkflowRunServiceQueries:
    """Query/delegation behaviour of WorkflowRunService against its repositories."""

    def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """Filters are forwarded verbatim and the string "limit" is coerced to int."""
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = MagicMock(name="pagination")
        workflow_run_repo.get_paginated_workflow_runs.return_value = expected
        args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
        result = service.get_paginate_workflow_runs(
            app_model=app_model,
            args=args,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert result is expected
        workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            limit=7,
            last_id="last-1",
            status="succeeded",
        )

    def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Runs with a message gain message_id/conversation_id; runs without stay untouched."""
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        run_with_message = SimpleNamespace(
            id="run-1",
            status="running",
            message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
        )
        run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
        pagination = SimpleNamespace(data=[run_with_message, run_without_message])
        monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
        result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
        assert result is pagination
        assert len(result.data) == 2
        assert result.data[0].message_id == "msg-1"
        assert result.data[0].conversation_id == "conv-1"
        assert result.data[0].status == "running"
        # The run without a message must NOT be decorated.
        assert not hasattr(result.data[1], "message_id")
        assert result.data[1].id == "run-2"

    def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """Lookup is delegated with the app's tenant_id/app_id plus the run id."""
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = MagicMock(name="workflow_run")
        workflow_run_repo.get_workflow_run_by_id.return_value = expected
        result = service.get_workflow_run(app_model=app_model, run_id="run-1")
        assert result is expected
        workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            run_id="run-1",
        )

    def test_get_workflow_runs_count_should_forward_optional_filters(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        """Optional status/time_range filters pass straight through to the repository."""
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = {"total": 3, "succeeded": 2}
        workflow_run_repo.get_workflow_runs_count.return_value = expected
        result = service.get_workflow_runs_count(
            app_model=app_model,
            status="succeeded",
            time_range="7d",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert result == expected
        workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            status="succeeded",
            time_range="7d",
        )

    def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An unknown run id short-circuits to an empty execution list."""
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id="tenant-1")
        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
        assert result == []

    def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """For EndUser callers the tenant id is read from user.tenant_id."""
        node_repo, _, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

        class FakeEndUser:
            def __init__(self, tenant_id: str) -> None:
                self.tenant_id = tenant_id

        # Patch EndUser so the service's isinstance check matches the fake.
        monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
        user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
        app_model = _app_model(id="app-1")
        expected = [SimpleNamespace(id="exec-1")]
        node_repo.get_executions_by_workflow_run.return_value = expected
        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
        assert result == expected
        node_repo.get_executions_by_workflow_run.assert_called_once_with(
            tenant_id="tenant-end-user",
            app_id="app-1",
            workflow_run_id="run-1",
        )

    def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """For Account callers the tenant id is read from user.current_tenant_id."""
        node_repo, _, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id="tenant-account")
        expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
        node_repo.get_executions_by_workflow_run.return_value = expected
        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
        assert result == expected
        node_repo.get_executions_by_workflow_run.assert_called_once_with(
            tenant_id="tenant-account",
            app_id="app-1",
            workflow_run_id="run-1",
        )

    def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A caller with no resolvable tenant id triggers a ValueError."""
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id=None)
        with pytest.raises(ValueError, match="tenant_id cannot be None"):
            service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

View File

@ -176,300 +176,3 @@ class TestWorkflowRunService:
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
# === Merged from test_workflow_run_service.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
    """Patch DifyAPIRepositoryFactory and return (node_repo, workflow_run_repo, factory).

    NOTE(review): duplicate of the fixture defined earlier in this file (merged module).
    """
    # Arrange
    node_repo = MagicMock()
    workflow_run_repo = MagicMock()
    factory = SimpleNamespace(
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    # Hand the mocks back to the test (was mislabelled "# Assert" — this is a return, not an assertion).
    return node_repo, workflow_run_repo, factory
def _app_model(**kwargs: Any) -> App:
    """Cast a SimpleNamespace stub to App. NOTE(review): duplicate helper from the merged module."""
    return cast(App, SimpleNamespace(**kwargs))


def _account(**kwargs: Any) -> Account:
    """Cast a SimpleNamespace stub to Account."""
    return cast(Account, SimpleNamespace(**kwargs))


def _end_user(**kwargs: Any) -> EndUser:
    """Cast a SimpleNamespace stub to EndUser."""
    return cast(EndUser, SimpleNamespace(**kwargs))
# NOTE(review): the functions below duplicate TestWorkflowRunServiceInitialization above
# at module level ("Merged from test_workflow_run_service.py") — candidates for removal.
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """With no factory given, a sessionmaker is built from the global db engine."""
    # Arrange
    session_factory = MagicMock(name="session_factory")
    sessionmaker_mock = MagicMock(return_value=session_factory)
    monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
    # Act
    service = WorkflowRunService()
    # Assert
    sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
    assert service._session_factory is session_factory


def test___init___should_create_sessionmaker_when_engine_is_provided(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Passing an Engine instance wraps it in a fresh sessionmaker."""
    # Arrange
    class FakeEngine:
        pass

    session_factory = MagicMock(name="session_factory")
    sessionmaker_mock = MagicMock(return_value=session_factory)
    monkeypatch.setattr(service_module, "Engine", FakeEngine)
    monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
    engine = cast(Engine, FakeEngine())
    # Act
    service = WorkflowRunService(session_factory=engine)
    # Assert
    sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
    assert service._session_factory is session_factory


def test___init___should_keep_provided_sessionmaker_and_create_repositories(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """A ready-made sessionmaker is stored as-is; both repositories derive from it."""
    # Arrange
    node_repo, workflow_run_repo, factory = repository_factory_mocks
    session_factory = MagicMock(name="session_factory")
    # Act
    service = WorkflowRunService(session_factory=session_factory)
    # Assert
    assert service._session_factory is session_factory
    assert service._node_execution_service_repo is node_repo
    assert service._workflow_run_repo is workflow_run_repo
    factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
    factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
# NOTE(review): duplicates of TestWorkflowRunServiceQueries methods above (merged module).
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Filters are forwarded verbatim and the string "limit" is coerced to int."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = MagicMock(name="pagination")
    workflow_run_repo.get_paginated_workflow_runs.return_value = expected
    args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
    # Act
    result = service.get_paginate_workflow_runs(
        app_model=app_model,
        args=args,
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Assert
    assert result is expected
    workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        limit=7,
        last_id="last-1",
        status="succeeded",
    )


def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Runs with a message gain message_id/conversation_id; runs without stay untouched."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    run_with_message = SimpleNamespace(
        id="run-1",
        status="running",
        message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
    )
    run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
    pagination = SimpleNamespace(data=[run_with_message, run_without_message])
    monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
    # Act
    result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
    # Assert
    assert result is pagination
    assert len(result.data) == 2
    assert result.data[0].message_id == "msg-1"
    assert result.data[0].conversation_id == "conv-1"
    assert result.data[0].status == "running"
    assert not hasattr(result.data[1], "message_id")
    assert result.data[1].id == "run-2"


def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Lookup is delegated with the app's tenant_id/app_id plus the run id."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = MagicMock(name="workflow_run")
    workflow_run_repo.get_workflow_run_by_id.return_value = expected
    # Act
    result = service.get_workflow_run(app_model=app_model, run_id="run-1")
    # Assert
    assert result is expected
    workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        run_id="run-1",
    )


def test_get_workflow_runs_count_should_forward_optional_filters(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Optional status/time_range filters pass straight through to the repository."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = {"total": 3, "succeeded": 2}
    workflow_run_repo.get_workflow_runs_count.return_value = expected
    # Act
    result = service.get_workflow_runs_count(
        app_model=app_model,
        status="succeeded",
        time_range="7d",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Assert
    assert result == expected
    workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        status="succeeded",
        time_range="7d",
    )
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unknown run id short-circuits to an empty execution list."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id="tenant-1")
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == []


def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For EndUser callers the tenant id is read from user.tenant_id."""
    # Arrange
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

    class FakeEndUser:
        def __init__(self, tenant_id: str) -> None:
            self.tenant_id = tenant_id

    # Patch EndUser so the service's isinstance check matches the fake.
    monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
    user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
    app_model = _app_model(id="app-1")
    expected = [SimpleNamespace(id="exec-1")]
    node_repo.get_executions_by_workflow_run.return_value = expected
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == expected
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-end-user",
        app_id="app-1",
        workflow_run_id="run-1",
    )


def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For Account callers the tenant id is read from user.current_tenant_id."""
    # Arrange
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id="tenant-account")
    expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
    node_repo.get_executions_by_workflow_run.return_value = expected
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == expected
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-account",
        app_id="app-1",
        workflow_run_id="run-1",
    )


def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A caller with no resolvable tenant id triggers a ValueError."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id=None)
    # Act / Assert
    with pytest.raises(ValueError, match="tenant_id cannot be None"):
        service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

View File

@ -3,7 +3,6 @@ import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
import pytest
@ -223,577 +222,3 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected
# === Merged from test_workflow_event_snapshot_service_additional.py ===
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    """Build a minimal WorkflowRun row (JSON-encoded inputs/outputs) for stream tests."""
    return WorkflowRun(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"query": "hello"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=1.2,
        total_tokens=5,
        total_steps=2,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
    )
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
    """Assemble a serializable WorkflowResumptionContext around a generate entity with *task_id*."""
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=app_config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    runtime_state.outputs = {"answer": "ok"}
    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
    return WorkflowResumptionContext(
        generate_entity=wrapper,
        serialized_graph_runtime_state=runtime_state.dumps(),
    )
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
    """Callable mimicking sessionmaker: every call yields a fresh context manager."""

    def __init__(self, session: Any) -> None:
        self._wrapped = session

    def __call__(self) -> _SessionContext:
        return _SessionContext(self._wrapped)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
    """Topic stub whose subscribe() wraps the held subscription in a context manager."""

    def __init__(self, subscription: Any) -> None:
        self._wrapped = subscription

    def subscribe(self) -> _SubscriptionContext:
        return _SubscriptionContext(self._wrapped)
class _StaticSubscription:
def receive(self, timeout: int = 1) -> None:
return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    """Frozen WorkflowPauseEntity stub carrying raw serialized state bytes."""

    # Serialized resumption-context payload handed back by get_state().
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state

    def get_pause_reasons(self) -> list[Any]:
        return []
def test_get_message_context_should_return_none_when_no_message() -> None:
    """No message row for the run means no MessageContext."""
    # Arrange
    session = SimpleNamespace(scalar=MagicMock(return_value=None))
    session_maker = _SessionMaker(session)
    # Act
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
    # Assert
    assert result is None


def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
    """A message without created_at yields created_at == 0 in the context."""
    # Arrange
    message = SimpleNamespace(
        id="msg-1",
        conversation_id="conv-1",
        created_at=None,
        answer="answer",
    )
    session = SimpleNamespace(scalar=MagicMock(return_value=message))
    session_maker = _SessionMaker(session)
    # Act
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
    # Assert
    assert result is not None
    assert result.created_at == 0
    assert result.message_id == "msg-1"
    assert result.conversation_id == "conv-1"
    assert result.answer == "answer"
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
    """None in, None out."""
    # Arrange
    # Act
    result = service_module._load_resumption_context(None)
    # Assert
    assert result is None


def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
    """Undecodable state bytes are treated as no context rather than raising."""
    # Arrange
    pause_entity = _PauseEntity(state=b"not-a-valid-state")
    # Act
    result = service_module._load_resumption_context(pause_entity)
    # Assert
    assert result is None


def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
    """Round-trip: a dumped context parses back with its task id intact."""
    # Arrange
    context = _build_resumption_context_additional(task_id="task-ctx")
    pause_entity = _PauseEntity(state=context.dumps().encode())
    # Act
    result = service_module._load_resumption_context(pause_entity)
    # Assert
    assert result is not None
    assert result.get_generate_entity().task_id == "task-ctx"
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
    """Without a buffer state the workflow run id doubles as the task id."""
    # Arrange
    # Act
    result = service_module._resolve_task_id(
        resumption_context=None,
        buffer_state=None,
        workflow_run_id="run-1",
    )
    # Assert
    assert result == "run-1"


@pytest.mark.parametrize(
    ("payload", "expected"),
    [
        (b'{"event":"node_started"}', {"event": "node_started"}),
        (b"invalid-json", None),
        (b"[]", None),
    ],
)
def test_parse_event_message_should_parse_only_json_object(
    payload: bytes,
    expected: dict[str, Any] | None,
) -> None:
    """Only a JSON *object* parses; invalid JSON and JSON arrays both yield None."""
    # Arrange
    # Act
    result = service_module._parse_event_message(payload)
    # Assert
    assert result == expected
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
    """WORKFLOW_FINISHED is always terminal; WORKFLOW_PAUSED only when include_paused is set."""
    # Arrange
    finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
    paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
    # Act
    is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
    paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
    paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
    # Assert
    assert is_finished is True
    assert paused_without_flag is False
    assert paused_with_flag is True
    # A bare PING string (not a payload dict) is never terminal.
    assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False


def test_apply_message_context_should_update_payload_when_context_exists() -> None:
    """Message context fields are merged into the event payload in place."""
    # Arrange
    payload: dict[str, Any] = {"event": "workflow_started"}
    context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
    # Act
    service_module._apply_message_context(payload, context)
    # Assert
    assert payload["conversation_id"] == "conv-1"
    assert payload["message_id"] == "msg-1"
    assert payload["created_at"] == 1700000000
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
    """The background buffer records the first event's task_id and queues the event."""
    # Arrange
    class Subscription:
        # Emits one JSON event on the first receive(), then behaves as timed out.
        def __init__(self) -> None:
            self._calls = 0

        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-1"}'
            return None

    subscription = Subscription()
    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    event = buffer_state.queue.get(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)
    # Assert
    assert ready is True
    assert finished is True
    assert buffer_state.task_id_hint == "task-1"
    assert event["event"] == "node_started"


def test_start_buffering_should_drop_old_event_when_queue_is_full(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the buffer queue is full, the oldest event is evicted to admit the new one."""
    # Arrange
    class QueueWithSingleFull:
        # queue.Queue stand-in that reports Full exactly once, pre-seeded with one old event.
        def __init__(self) -> None:
            self._first_put = True
            self.items: list[dict[str, Any]] = [{"event": "old"}]

        def put_nowait(self, item: dict[str, Any]) -> None:
            if self._first_put:
                self._first_put = False
                raise queue.Full
            self.items.append(item)

        def get_nowait(self) -> dict[str, Any]:
            if not self.items:
                raise queue.Empty
            return self.items.pop(0)

        def empty(self) -> bool:
            return len(self.items) == 0

    fake_queue = QueueWithSingleFull()
    monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)

    class Subscription:
        def __init__(self) -> None:
            self._calls = 0

        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-2"}'
            return None

    subscription = Subscription()
    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)
    # Assert
    assert ready is True
    assert finished is True
    # The new event made it in after the eviction path ran.
    assert fake_queue.items[-1]["task_id"] == "task-2"


def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
    """A crashing subscription still flips done_event so consumers don't hang."""
    # Arrange
    class Subscription:
        def receive(self, timeout: int = 1) -> bytes | None:
            raise RuntimeError("subscription failure")

    subscription = Subscription()
    # Act
    buffer_state = service_module._start_buffering(subscription)
    finished = buffer_state.done_event.wait(timeout=1)
    # Assert
    assert finished is True
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The stream opens with PING, replays snapshot events, and stops on a terminal event."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(
        service_module,
        "_get_message_context",
        MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
    )
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    # The snapshot replay itself carries the terminal WORKFLOW_FINISHED event.
    monkeypatch.setattr(
        service_module,
        "_build_snapshot_events",
        MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
    )
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.ADVANCED_CHAT,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events[0] == StreamEvent.PING.value
    finished_event = cast(Mapping[str, Any], events[1])
    assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    assert buffer_state.stop_event.is_set() is True
    node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
    called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
    assert called_kwargs["workflow_run_id"] == "run-1"
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The stream pings once per ping_interval and terminates after idle_timeout with no events."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))

    class AlwaysEmptyQueue:
        # NOTE: empty() deliberately returns False so the stream keeps polling get(),
        # which always raises queue.Empty — simulating a live but idle buffer.
        def empty(self) -> bool:
            return False

        def get(self, timeout: int = 1) -> None:
            raise queue.Empty

    buffer_state = BufferState(
        queue=AlwaysEmptyQueue(),  # type: ignore[arg-type]
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Fake clock: 6.0 crosses ping_interval (5.0) once, 21.0 crosses idle_timeout (20.0).
    time_values = cycle([0.0, 6.0, 21.0, 26.0])
    monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
            idle_timeout=20.0,
            ping_interval=5.0,
        )
    )
    # Assert
    # Initial ping plus one interval ping; idle timeout then ends the stream.
    assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When buffering has already completed and the queue is drained, the stream ends after one ping."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    # Mark buffering as finished before streaming starts; the queue is empty.
    buffer_state.done_event.set()
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events == [StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failing get_workflow_pause must not break the stream; snapshots get pause_entity=None."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
    topic = _Topic(_StaticSubscription())
    # Simulate a repository failure while loading the pause record.
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    # Keep a handle on the snapshot builder so its call kwargs can be asserted below.
    snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
    monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events[0] == StreamEvent.PING.value
    # The pause-loading failure degrades to "no pause entity" rather than raising.
    assert snapshot_builder.call_args.kwargs["pause_entity"] is None

View File

@ -0,0 +1,505 @@
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    """Build a minimal WorkflowRun fixture; only the execution ``status`` varies between tests."""
    attributes: dict[str, Any] = {
        "id": "run-1",
        "tenant_id": "tenant-1",
        "app_id": "app-1",
        "workflow_id": "workflow-1",
        "type": "workflow",
        "triggered_from": "app-run",
        "version": "v1",
        "graph": None,
        "inputs": json.dumps({"query": "hello"}),
        "status": status,
        "outputs": json.dumps({}),
        "error": None,
        "elapsed_time": 1.2,
        "total_tokens": 5,
        "total_steps": 2,
        "created_by_role": CreatorUserRole.END_USER,
        "created_by": "user-1",
        "created_at": datetime(2024, 1, 1, tzinfo=UTC),
    }
    return WorkflowRun(**attributes)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
    """Assemble a serializable resumption context around a workflow generate entity.

    The entity carries ``task_id`` so tests can verify it survives a dump/load round trip.
    """
    config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    state.outputs = {"answer": "ok"}
    return WorkflowResumptionContext(
        generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
        serialized_graph_runtime_state=state.dumps(),
    )
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
    """Callable standing in for a sessionmaker: each call wraps the same session in a context manager."""

    def __init__(self, session: Any) -> None:
        self._stored = session

    def __call__(self) -> _SessionContext:
        return _SessionContext(self._stored)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
    """Topic stub whose subscribe() yields the injected subscription via a context manager."""

    def __init__(self, subscription: Any) -> None:
        self._sub = subscription

    def subscribe(self) -> _SubscriptionContext:
        return _SubscriptionContext(self._sub)
class _StaticSubscription:
def receive(self, timeout: int = 1) -> None:
return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    """Immutable WorkflowPauseEntity test double; the persisted state is injected as raw bytes."""

    # Raw serialized resumption state returned verbatim by get_state().
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        # Never resumed in these tests.
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state

    def get_pause_reasons(self) -> list[Any]:
        # No pause reasons are needed by the tests exercising this double.
        return []
class TestWorkflowEventSnapshotHelpers:
    """Unit tests for the private helpers in workflow_event_snapshot_service."""

    def test_get_message_context_should_return_none_when_no_message(self) -> None:
        """No message row for the run -> helper returns None."""
        session = SimpleNamespace(scalar=MagicMock(return_value=None))
        session_maker = _SessionMaker(session)
        result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
        assert result is None

    def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp(self) -> None:
        """A message without created_at maps to created_at == 0 in the context."""
        message = SimpleNamespace(
            id="msg-1",
            conversation_id="conv-1",
            created_at=None,
            answer="answer",
        )
        session = SimpleNamespace(scalar=MagicMock(return_value=message))
        session_maker = _SessionMaker(session)
        result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
        assert result is not None
        # Missing timestamp falls back to 0 rather than None.
        assert result.created_at == 0
        assert result.message_id == "msg-1"
        assert result.conversation_id == "conv-1"
        assert result.answer == "answer"

    def test_load_resumption_context_should_return_none_when_pause_entity_missing(self) -> None:
        assert service_module._load_resumption_context(None) is None

    def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid(self) -> None:
        # Undeserializable state bytes are treated as "no resumption context".
        pause_entity = _PauseEntity(state=b"not-a-valid-state")
        assert service_module._load_resumption_context(pause_entity) is None

    def test_load_resumption_context_should_parse_valid_state_into_context(self) -> None:
        """A dumped context stored in the pause entity round-trips through the loader."""
        context = _build_resumption_context(task_id="task-ctx")
        pause_entity = _PauseEntity(state=context.dumps().encode())
        result = service_module._load_resumption_context(pause_entity)
        assert result is not None
        # Round-trip preserves the generate entity's task id.
        assert result.get_generate_entity().task_id == "task-ctx"

    def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing(self) -> None:
        """With no context and no buffer hint, the workflow run id is the fallback task id."""
        result = service_module._resolve_task_id(
            resumption_context=None,
            buffer_state=None,
            workflow_run_id="run-1",
        )
        assert result == "run-1"

    @pytest.mark.parametrize(
        ("payload", "expected"),
        [
            (b'{"event":"node_started"}', {"event": "node_started"}),
            (b"invalid-json", None),
            (b"[]", None),  # valid JSON but not an object -> rejected
        ],
    )
    def test_parse_event_message_should_parse_only_json_object(
        self,
        payload: bytes,
        expected: dict[str, Any] | None,
    ) -> None:
        result = service_module._parse_event_message(payload)
        assert result == expected

    def test_is_terminal_event_should_recognize_finished_and_optional_paused_events(self) -> None:
        """WORKFLOW_FINISHED is always terminal; WORKFLOW_PAUSED only when include_paused=True."""
        finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
        paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
        is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
        paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
        paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
        assert is_finished is True
        assert paused_without_flag is False
        assert paused_with_flag is True
        # Non-mapping payloads (e.g. a bare ping string) are never terminal.
        assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False

    def test_apply_message_context_should_update_payload_when_context_exists(self) -> None:
        """Conversation/message/timestamp fields are merged into the event payload in place."""
        payload: dict[str, Any] = {"event": "workflow_started"}
        context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
        service_module._apply_message_context(payload, context)
        assert payload["conversation_id"] == "conv-1"
        assert payload["message_id"] == "msg-1"
        assert payload["created_at"] == 1700000000

    def test_start_buffering_should_capture_task_id_and_enqueue_event(self) -> None:
        """The buffer worker records the first task_id it sees and enqueues the event."""

        class Subscription:
            # Yields a single task-bearing event, then behaves as idle.
            def __init__(self) -> None:
                self._calls = 0

            def receive(self, timeout: int = 1) -> bytes | None:
                self._calls += 1
                if self._calls == 1:
                    return b'{"event":"node_started","task_id":"task-1"}'
                return None

        subscription = Subscription()
        buffer_state = service_module._start_buffering(subscription)
        # Bounded waits keep the test fast even if the buffering thread misbehaves.
        ready = buffer_state.task_id_ready.wait(timeout=1)
        event = buffer_state.queue.get(timeout=1)
        buffer_state.stop_event.set()
        finished = buffer_state.done_event.wait(timeout=1)
        assert ready is True
        assert finished is True
        assert buffer_state.task_id_hint == "task-1"
        assert event["event"] == "node_started"

    def test_start_buffering_should_drop_old_event_when_queue_is_full(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """On queue.Full the buffer drops the oldest item and retries the new event."""

        class QueueWithSingleFull:
            # Raises queue.Full only on the first put, forcing the drop-oldest
            # path; subsequent puts succeed normally.
            def __init__(self) -> None:
                self._first_put = True
                self.items: list[dict[str, Any]] = [{"event": "old"}]

            def put_nowait(self, item: dict[str, Any]) -> None:
                if self._first_put:
                    self._first_put = False
                    raise queue.Full
                self.items.append(item)

            def get_nowait(self) -> dict[str, Any]:
                if not self.items:
                    raise queue.Empty
                return self.items.pop(0)

            def empty(self) -> bool:
                return len(self.items) == 0

        fake_queue = QueueWithSingleFull()
        # Make _start_buffering construct our fake queue instead of queue.Queue.
        monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)

        class Subscription:
            def __init__(self) -> None:
                self._calls = 0

            def receive(self, timeout: int = 1) -> bytes | None:
                self._calls += 1
                if self._calls == 1:
                    return b'{"event":"node_started","task_id":"task-2"}'
                return None

        subscription = Subscription()
        buffer_state = service_module._start_buffering(subscription)
        ready = buffer_state.task_id_ready.wait(timeout=1)
        buffer_state.stop_event.set()
        finished = buffer_state.done_event.wait(timeout=1)
        assert ready is True
        assert finished is True
        # The oldest item was displaced and the new event made it into the queue.
        assert fake_queue.items[-1]["task_id"] == "task-2"

    def test_start_buffering_should_set_done_event_when_subscription_raises(self) -> None:
        """Even on a subscription failure the worker signals done so consumers don't hang."""

        class Subscription:
            def receive(self, timeout: int = 1) -> bytes | None:
                raise RuntimeError("subscription failure")

        subscription = Subscription()
        buffer_state = service_module._start_buffering(subscription)
        assert buffer_state.done_event.wait(timeout=1) is True
class TestBuildWorkflowEventStream:
    """End-to-end tests for build_workflow_event_stream with all collaborators stubbed out."""

    def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An initial ping is followed by the terminal snapshot event; buffering is stopped."""
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        topic = _Topic(_StaticSubscription())
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(
            service_module,
            "_get_message_context",
            MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
        )
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        buffer_state = BufferState(
            queue=queue.Queue(),
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
        monkeypatch.setattr(
            service_module,
            "_build_snapshot_events",
            MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
        )
        events = list(
            build_workflow_event_stream(
                app_mode=AppMode.ADVANCED_CHAT,
                workflow_run=workflow_run,
                tenant_id="tenant-1",
                app_id="app-1",
                session_maker=MagicMock(),
            )
        )
        assert events[0] == StreamEvent.PING.value
        finished_event = cast(Mapping[str, Any], events[1])
        assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
        assert buffer_state.stop_event.is_set() is True
        # Snapshots were fetched exactly once, scoped to this workflow run.
        node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
        called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
        assert called_kwargs["workflow_run_id"] == "run-1"

    def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The stream pings once per ping_interval and terminates after idle_timeout with no events."""
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        topic = _Topic(_StaticSubscription())
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))

        class AlwaysEmptyQueue:
            # NOTE: empty() deliberately returns False so the stream keeps polling
            # get(), which always raises queue.Empty — a live but idle buffer.
            def empty(self) -> bool:
                return False

            def get(self, timeout: int = 1) -> None:
                raise queue.Empty

        buffer_state = BufferState(
            queue=AlwaysEmptyQueue(),  # type: ignore[arg-type]
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        # Fake clock: 6.0 crosses ping_interval (5.0) once, 21.0 crosses idle_timeout (20.0).
        time_values = cycle([0.0, 6.0, 21.0, 26.0])
        monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
        events = list(
            build_workflow_event_stream(
                app_mode=AppMode.WORKFLOW,
                workflow_run=workflow_run,
                tenant_id="tenant-1",
                app_id="app-1",
                session_maker=MagicMock(),
                idle_timeout=20.0,
                ping_interval=5.0,
            )
        )
        assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
        assert buffer_state.stop_event.is_set() is True

    def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """When buffering has already completed and the queue is drained, the stream ends after one ping."""
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        topic = _Topic(_StaticSubscription())
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
        buffer_state = BufferState(
            queue=queue.Queue(),
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        # Mark buffering as finished before streaming starts; the queue is empty.
        buffer_state.done_event.set()
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        events = list(
            build_workflow_event_stream(
                app_mode=AppMode.WORKFLOW,
                workflow_run=workflow_run,
                tenant_id="tenant-1",
                app_id="app-1",
                session_maker=MagicMock(),
            )
        )
        assert events == [StreamEvent.PING.value]
        assert buffer_state.stop_event.is_set() is True

    def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A failing get_workflow_pause must not break the stream; snapshots get pause_entity=None."""
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.PAUSED)
        topic = _Topic(_StaticSubscription())
        # Simulate a repository failure while loading the pause record.
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
        # Keep a handle on the snapshot builder so its call kwargs can be asserted below.
        snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
        monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
        buffer_state = BufferState(
            queue=queue.Queue(),
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        events = list(
            build_workflow_event_stream(
                app_mode=AppMode.WORKFLOW,
                workflow_run=workflow_run,
                tenant_id="tenant-1",
                app_id="app-1",
                session_maker=MagicMock(),
            )
        )
        assert events[0] == StreamEvent.PING.value
        # The pause-loading failure degrades to "no pause entity" rather than raising.
        assert snapshot_builder.call_args.kwargs["pause_entity"] is None