mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
test: split merged API test modules and remove F811 ignore (#35105)
This commit is contained in:
parent
178883b4cc
commit
28185170b0
@ -97,7 +97,6 @@ ignore = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
]
|
||||
|
||||
@ -95,30 +95,6 @@ class TestTextToAudioPayload:
|
||||
assert payload.streaming is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioService Interface Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test AudioService method interfaces exist."""
|
||||
|
||||
def test_transcript_asr_method_exists(self):
|
||||
"""Test that AudioService.transcript_asr exists."""
|
||||
assert hasattr(AudioService, "transcript_asr")
|
||||
assert callable(AudioService.transcript_asr)
|
||||
|
||||
def test_transcript_tts_method_exists(self):
|
||||
"""Test that AudioService.transcript_tts exists."""
|
||||
assert hasattr(AudioService, "transcript_tts")
|
||||
assert callable(AudioService.transcript_tts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio Service Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test suite for AudioService interface methods."""
|
||||
|
||||
|
||||
@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi:
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
|
||||
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
|
||||
@ -73,11 +73,6 @@ class TestAsyncWorkflowService:
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
mock_team_task = MagicMock()
|
||||
mock_sandbox_task = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
|
||||
602
api/tests/unit_tests/services/test_model_provider_service.py
Normal file
602
api/tests/unit_tests/services/test_model_provider_service.py
Normal file
@ -0,0 +1,602 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from models.provider import ProviderType
|
||||
from services import model_provider_service as service_module
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
|
||||
manager = MagicMock()
|
||||
service = ModelProviderService()
|
||||
service._get_provider_manager = MagicMock(return_value=manager)
|
||||
return service, manager
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
provider_name: str = "openai",
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
custom_models: list[Any] | None = None,
|
||||
custom_config_available: bool = True,
|
||||
) -> SimpleNamespace:
|
||||
if supported_model_types is None:
|
||||
supported_model_types = [ModelType.LLM]
|
||||
|
||||
return SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
provider=provider_name,
|
||||
label=I18nObject(en_US=provider_name),
|
||||
description=None,
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
background=None,
|
||||
help=None,
|
||||
supported_model_types=supported_model_types,
|
||||
configurate_methods=[],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
),
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
custom_configuration=SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
current_credential_id="cred-1",
|
||||
current_credential_name="Credential 1",
|
||||
available_credentials=[],
|
||||
),
|
||||
models=custom_models,
|
||||
can_added_models=[],
|
||||
),
|
||||
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
|
||||
is_custom_configuration_available=lambda: custom_config_available,
|
||||
)
|
||||
|
||||
|
||||
class TestModelProviderServiceConfiguration:
|
||||
def test__get_provider_configuration_should_return_configuration_when_provider_exists(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
provider_configuration = SimpleNamespace(name="provider-config")
|
||||
manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
assert result is provider_configuration
|
||||
|
||||
def test__get_provider_configuration_should_raise_error_when_provider_is_missing(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_configurations.return_value = {}
|
||||
|
||||
with pytest.raises(ProviderNotFoundError, match="does not exist"):
|
||||
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
|
||||
|
||||
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
allowed = _build_provider_configuration(
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
custom_config_available=False,
|
||||
)
|
||||
filtered = _build_provider_configuration(
|
||||
provider_name="embedding",
|
||||
supported_model_types=[ModelType.TEXT_EMBEDDING],
|
||||
custom_config_available=True,
|
||||
)
|
||||
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
|
||||
|
||||
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert result[0].custom_configuration.status.value == "no-configure"
|
||||
|
||||
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
class _Model:
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"label": {"en_US": self.model_name},
|
||||
"model_type": ModelType.LLM,
|
||||
"features": [],
|
||||
"fetch_from": FetchFrom.PREDEFINED_MODEL,
|
||||
"model_properties": {},
|
||||
"deprecated": False,
|
||||
"status": ModelStatus.ACTIVE,
|
||||
"load_balancing_enabled": False,
|
||||
"has_invalid_load_balancing_configs": False,
|
||||
"provider": {
|
||||
"provider": "openai",
|
||||
"label": {"en_US": "OpenAI"},
|
||||
"icon_small": None,
|
||||
"icon_small_dark": None,
|
||||
"supported_model_types": [ModelType.LLM],
|
||||
},
|
||||
}
|
||||
|
||||
provider_configurations = SimpleNamespace(
|
||||
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
|
||||
)
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].model == "gpt-4o"
|
||||
assert result[1].provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderServiceDelegation:
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"get_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
{"token": "abc"},
|
||||
),
|
||||
(
|
||||
"validate_provider_credentials",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
|
||||
"validate_provider_credentials",
|
||||
({"token": "abc"},),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_provider_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_name": "A",
|
||||
},
|
||||
"create_provider_credential",
|
||||
({"token": "abc"}, "A"),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_provider_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "B",
|
||||
},
|
||||
"update_provider_credential",
|
||||
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"delete_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"switch_active_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_provider_credential_methods_should_delegate_to_provider_configuration(
|
||||
self,
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
provider_call_kwargs: Any,
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
provider_method = getattr(provider_configuration, provider_method_name)
|
||||
if isinstance(provider_call_kwargs, tuple):
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
elif isinstance(provider_call_kwargs, dict):
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
else:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
if method_name == "get_provider_credential":
|
||||
assert result == {"token": "abc"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"get_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
{"api_key": "x"},
|
||||
),
|
||||
(
|
||||
"validate_model_credentials",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
"validate_custom_model_credentials",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
"create_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
"update_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"delete_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_custom_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"switch_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"add_model_credential_to_model_list",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"add_model_credential_to_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
},
|
||||
"delete_custom_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_custom_model_methods_should_convert_model_type_and_delegate(
|
||||
self,
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
expected_kwargs: dict[str, Any],
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
|
||||
if method_name == "get_model_credential":
|
||||
assert result == {"api_key": "x"}
|
||||
|
||||
|
||||
class TestModelProviderServiceListingsAndDefaults:
|
||||
def test_get_models_by_model_type_should_group_active_non_deprecated_models(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
openai_provider = SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
anthropic_provider = SimpleNamespace(
|
||||
provider="anthropic",
|
||||
label=I18nObject(en_US="Anthropic"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
models = [
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="gpt-4o",
|
||||
label=I18nObject(en_US="GPT-4o"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=False,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="old-openai",
|
||||
label=I18nObject(en_US="Old OpenAI"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=anthropic_provider,
|
||||
model="old-anthropic",
|
||||
label=I18nObject(en_US="Old Anthropic"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
]
|
||||
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert len(result[0].models) == 1
|
||||
assert result[0].models[0].model == "gpt-4o"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "schema", "expected_count"),
|
||||
[
|
||||
(None, None, 0),
|
||||
({"api_key": "x"}, None, 0),
|
||||
(
|
||||
{"api_key": "x"},
|
||||
SimpleNamespace(
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
)
|
||||
]
|
||||
),
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
|
||||
self,
|
||||
credentials: dict[str, Any] | None,
|
||||
schema: Any,
|
||||
expected_count: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.get_current_credentials.return_value = credentials
|
||||
provider_configuration.get_model_schema.return_value = schema
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
|
||||
|
||||
assert len(result) == expected_count
|
||||
provider_configuration.get_current_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
)
|
||||
if credentials:
|
||||
provider_configuration.get_model_schema.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
provider_configuration.get_model_schema.assert_not_called()
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = SimpleNamespace(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
provider=SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4o"
|
||||
assert result.provider.provider == "openai"
|
||||
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = None
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.side_effect = RuntimeError("boom")
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_update_default_model_of_model_type_should_delegate_to_provider_manager(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
service.update_default_model_of_model_type(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM.value,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
manager.update_default_model_record.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
|
||||
factory_constructor = MagicMock(return_value=factory_instance)
|
||||
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
|
||||
|
||||
result = service.get_model_provider_icon(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
icon_type="icon_small",
|
||||
lang="en_US",
|
||||
)
|
||||
|
||||
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
|
||||
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
|
||||
assert result == (b"icon-bytes", "image/png")
|
||||
|
||||
def test_switch_preferred_provider_should_convert_enum_and_delegate(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
service.switch_preferred_provider(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
preferred_provider_type=ProviderType.SYSTEM.value,
|
||||
)
|
||||
|
||||
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "provider_method_name"),
|
||||
[
|
||||
("enable_model", "enable_model"),
|
||||
("disable_model", "disable_model"),
|
||||
],
|
||||
)
|
||||
def test_model_enablement_methods_should_convert_model_type_and_delegate(
|
||||
self,
|
||||
method_name: str,
|
||||
provider_method_name: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
getattr(service, method_name)(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM.value,
|
||||
)
|
||||
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
@ -85,644 +85,3 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
|
||||
assert len(custom_models) == 1
|
||||
# The sanitizer should drop credentials in list response
|
||||
assert custom_models[0].credentials is None
|
||||
|
||||
|
||||
# === Merged from test_model_provider_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from models.provider import ProviderType
|
||||
from services import model_provider_service as service_module
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
|
||||
manager = MagicMock()
|
||||
service = ModelProviderService()
|
||||
service._get_provider_manager = MagicMock(return_value=manager)
|
||||
return service, manager
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
provider_name: str = "openai",
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
custom_models: list[Any] | None = None,
|
||||
custom_config_available: bool = True,
|
||||
) -> SimpleNamespace:
|
||||
if supported_model_types is None:
|
||||
supported_model_types = [ModelType.LLM]
|
||||
return SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
provider=provider_name,
|
||||
label=I18nObject(en_US=provider_name),
|
||||
description=None,
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
background=None,
|
||||
help=None,
|
||||
supported_model_types=supported_model_types,
|
||||
configurate_methods=[],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
),
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
custom_configuration=SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
current_credential_id="cred-1",
|
||||
current_credential_name="Credential 1",
|
||||
available_credentials=[],
|
||||
),
|
||||
models=custom_models,
|
||||
can_added_models=[],
|
||||
),
|
||||
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
|
||||
is_custom_configuration_available=lambda: custom_config_available,
|
||||
)
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
provider_configuration = SimpleNamespace(name="provider-config")
|
||||
manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act
|
||||
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert result is provider_configuration
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_configurations.return_value = {}
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ProviderNotFoundError, match="does not exist"):
|
||||
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
|
||||
|
||||
|
||||
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
allowed = _build_provider_configuration(
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
custom_config_available=False,
|
||||
)
|
||||
filtered = _build_provider_configuration(
|
||||
provider_name="embedding",
|
||||
supported_model_types=[ModelType.TEXT_EMBEDDING],
|
||||
custom_config_available=True,
|
||||
)
|
||||
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
|
||||
|
||||
# Act
|
||||
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert result[0].custom_configuration.status.value == "no-configure"
|
||||
|
||||
|
||||
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
class _Model:
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"label": {"en_US": self.model_name},
|
||||
"model_type": ModelType.LLM,
|
||||
"features": [],
|
||||
"fetch_from": FetchFrom.PREDEFINED_MODEL,
|
||||
"model_properties": {},
|
||||
"deprecated": False,
|
||||
"status": ModelStatus.ACTIVE,
|
||||
"load_balancing_enabled": False,
|
||||
"has_invalid_load_balancing_configs": False,
|
||||
"provider": {
|
||||
"provider": "openai",
|
||||
"label": {"en_US": "OpenAI"},
|
||||
"icon_small": None,
|
||||
"icon_small_dark": None,
|
||||
"supported_model_types": [ModelType.LLM],
|
||||
},
|
||||
}
|
||||
|
||||
provider_configurations = SimpleNamespace(
|
||||
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
|
||||
)
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
# Act
|
||||
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].model == "gpt-4o"
|
||||
assert result[1].provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(provider="openai")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"get_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
{"token": "abc"},
|
||||
),
|
||||
(
|
||||
"validate_provider_credentials",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
|
||||
"validate_provider_credentials",
|
||||
({"token": "abc"},),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
|
||||
"create_provider_credential",
|
||||
({"token": "abc"}, "A"),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_provider_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "B",
|
||||
},
|
||||
"update_provider_credential",
|
||||
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"delete_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"switch_active_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_provider_credential_methods_should_delegate_to_provider_configuration(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
provider_call_kwargs: Any,
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
provider_method = getattr(provider_configuration, provider_method_name)
|
||||
if isinstance(provider_call_kwargs, tuple):
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
elif isinstance(provider_call_kwargs, dict):
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
else:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
if method_name == "get_provider_credential":
|
||||
assert result == {"token": "abc"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"get_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
{"api_key": "x"},
|
||||
),
|
||||
(
|
||||
"validate_model_credentials",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
"validate_custom_model_credentials",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
"create_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
"update_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"delete_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_custom_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"switch_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"add_model_credential_to_model_list",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"add_model_credential_to_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
},
|
||||
"delete_custom_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_custom_model_methods_should_convert_model_type_and_delegate(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
expected_kwargs: dict[str, Any],
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
|
||||
if method_name == "get_model_credential":
|
||||
assert result == {"api_key": "x"}
|
||||
|
||||
|
||||
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
openai_provider = SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
anthropic_provider = SimpleNamespace(
|
||||
provider="anthropic",
|
||||
label=I18nObject(en_US="Anthropic"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
models = [
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="gpt-4o",
|
||||
label=I18nObject(en_US="GPT-4o"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=False,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="old-openai",
|
||||
label=I18nObject(en_US="Old OpenAI"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=anthropic_provider,
|
||||
model="old-anthropic",
|
||||
label=I18nObject(en_US="Old Anthropic"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
]
|
||||
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
# Act
|
||||
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert len(result[0].models) == 1
|
||||
assert result[0].models[0].model == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "schema", "expected_count"),
|
||||
[
|
||||
(None, None, 0),
|
||||
({"api_key": "x"}, None, 0),
|
||||
(
|
||||
{"api_key": "x"},
|
||||
SimpleNamespace(
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
)
|
||||
]
|
||||
),
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
|
||||
credentials: dict[str, Any] | None,
|
||||
schema: Any,
|
||||
expected_count: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.get_current_credentials.return_value = credentials
|
||||
provider_configuration.get_model_schema.return_value = schema
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
|
||||
|
||||
# Assert
|
||||
assert len(result) == expected_count
|
||||
provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
|
||||
if credentials:
|
||||
provider_configuration.get_model_schema.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
provider_configuration.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = SimpleNamespace(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
provider=SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4o"
|
||||
assert result.provider.provider == "openai"
|
||||
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = None
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.side_effect = RuntimeError("boom")
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
# Act
|
||||
service.update_default_model_of_model_type(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM.value,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Assert
|
||||
manager.update_default_model_record.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
|
||||
factory_constructor = MagicMock(return_value=factory_instance)
|
||||
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
|
||||
|
||||
# Act
|
||||
result = service.get_model_provider_icon(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
icon_type="icon_small",
|
||||
lang="en_US",
|
||||
)
|
||||
|
||||
# Assert
|
||||
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
|
||||
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
|
||||
assert result == (b"icon-bytes", "image/png")
|
||||
|
||||
|
||||
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
service.switch_preferred_provider(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
preferred_provider_type=ProviderType.SYSTEM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "provider_method_name"),
|
||||
[
|
||||
("enable_model", "enable_model"),
|
||||
("disable_model", "disable_model"),
|
||||
],
|
||||
)
|
||||
def test_model_enablement_methods_should_convert_model_type_and_delegate(
|
||||
method_name: str,
|
||||
provider_method_name: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
getattr(service, method_name)(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
@ -12,7 +12,6 @@ This test suite covers all functionality of the current VariableTruncator includ
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@ -674,229 +673,3 @@ def test_dummy_variable_truncator_methods():
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
|
||||
# === Merged from test_variable_truncator_additional.py ===
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
||||
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
from services import variable_truncator as truncator_module
|
||||
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
|
||||
|
||||
|
||||
class _AbstractPassthrough(BaseTruncator):
|
||||
def truncate(self, segment: Any) -> TruncationResult:
|
||||
# Arrange / Act
|
||||
return super().truncate(segment) # type: ignore[misc]
|
||||
|
||||
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
|
||||
# Arrange / Act
|
||||
return super().truncate_variable_mapping(v) # type: ignore[misc]
|
||||
|
||||
|
||||
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
|
||||
# Arrange
|
||||
passthrough = _AbstractPassthrough()
|
||||
|
||||
# Act
|
||||
truncate_result = passthrough.truncate(StringSegment(value="x"))
|
||||
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
|
||||
|
||||
# Assert
|
||||
assert truncate_result is None
|
||||
assert mapping_result is None
|
||||
|
||||
|
||||
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
|
||||
|
||||
# Act
|
||||
truncator = VariableTruncator.default()
|
||||
|
||||
# Assert
|
||||
assert truncator._max_size_bytes == 111
|
||||
assert truncator._array_element_limit == 7
|
||||
assert truncator._string_length_limit == 33
|
||||
|
||||
|
||||
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=5)
|
||||
mapping = {"very_long_key": "value"}
|
||||
|
||||
# Act
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
# Assert
|
||||
assert result == {"very_long_key": "..."}
|
||||
assert truncated is True
|
||||
|
||||
|
||||
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
mapping = {"seg": StringSegment(value="hello")}
|
||||
|
||||
# Act
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result["seg"], StringSegment)
|
||||
assert result["seg"].value == "hello"
|
||||
assert truncated is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, False),
|
||||
(True, False),
|
||||
(1, False),
|
||||
(1.5, False),
|
||||
("x", True),
|
||||
({"k": "v"}, True),
|
||||
],
|
||||
)
|
||||
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = VariableTruncator._json_value_needs_truncation(value)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_result = truncator_module._PartResult(
|
||||
value=StringSegment(value="this is too long"),
|
||||
value_size=100,
|
||||
truncated=True,
|
||||
)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
# Act
|
||||
result = truncator.truncate(StringSegment(value="input"))
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert not result.result.value.startswith('"')
|
||||
|
||||
|
||||
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator()
|
||||
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_segment(IntegerSegment(value=1), 10)
|
||||
|
||||
|
||||
def test_calculate_json_size_should_unwrap_segment_values() -> None:
|
||||
# Arrange
|
||||
segment = StringSegment(value="abc")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(segment)
|
||||
|
||||
# Assert
|
||||
assert size == VariableTruncator.calculate_json_size("abc")
|
||||
|
||||
|
||||
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
|
||||
# Arrange
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(updated)
|
||||
|
||||
# Assert
|
||||
assert size > 0
|
||||
|
||||
|
||||
def test_maybe_qa_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
|
||||
assert VariableTruncator._maybe_qa_structure({}) is False
|
||||
|
||||
|
||||
def test_maybe_parent_child_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
|
||||
assert (
|
||||
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False
|
||||
)
|
||||
|
||||
|
||||
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
|
||||
mapping = {"s": StringSegment(value="long-content")}
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_object(mapping, 20)
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value["s"], StringSegment)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_json_primitives(updated, 100)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator()
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_segment = ObjectSegment(value={"k": "v"})
|
||||
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
# Act
|
||||
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
|
||||
@ -0,0 +1,174 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
||||
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
from services import variable_truncator as truncator_module
|
||||
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
|
||||
|
||||
|
||||
class _AbstractPassthrough(BaseTruncator):
|
||||
def truncate(self, segment: Any) -> TruncationResult:
|
||||
return super().truncate(segment) # type: ignore[misc]
|
||||
|
||||
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
|
||||
return super().truncate_variable_mapping(v) # type: ignore[misc]
|
||||
|
||||
|
||||
class TestBaseTruncatorContract:
|
||||
def test_base_truncator_methods_should_execute_abstract_placeholders(self) -> None:
|
||||
passthrough = _AbstractPassthrough()
|
||||
|
||||
truncate_result = passthrough.truncate(StringSegment(value="x"))
|
||||
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
|
||||
|
||||
assert truncate_result is None
|
||||
assert mapping_result is None
|
||||
|
||||
|
||||
class TestVariableTruncatorAdditionalBehavior:
|
||||
def test_default_should_use_dify_config_limits(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
|
||||
|
||||
truncator = VariableTruncator.default()
|
||||
|
||||
assert truncator._max_size_bytes == 111
|
||||
assert truncator._array_element_limit == 7
|
||||
assert truncator._string_length_limit == 33
|
||||
|
||||
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis(self) -> None:
|
||||
truncator = VariableTruncator(max_size_bytes=5)
|
||||
mapping = {"very_long_key": "value"}
|
||||
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
assert result == {"very_long_key": "..."}
|
||||
assert truncated is True
|
||||
|
||||
def test_truncate_variable_mapping_should_handle_segment_values(self) -> None:
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
mapping = {"seg": StringSegment(value="hello")}
|
||||
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
assert isinstance(result["seg"], StringSegment)
|
||||
assert result["seg"].value == "hello"
|
||||
assert truncated is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, False),
|
||||
(True, False),
|
||||
(1, False),
|
||||
(1.5, False),
|
||||
("x", True),
|
||||
({"k": "v"}, True),
|
||||
],
|
||||
)
|
||||
def test_json_value_needs_truncation_should_match_expected_rules(
|
||||
self,
|
||||
value: Any,
|
||||
expected: bool,
|
||||
) -> None:
|
||||
result = VariableTruncator._json_value_needs_truncation(value)
|
||||
assert result is expected
|
||||
|
||||
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_result = truncator_module._PartResult(
|
||||
value=StringSegment(value="this is too long"),
|
||||
value_size=100,
|
||||
truncated=True,
|
||||
)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
result = truncator.truncate(StringSegment(value="input"))
|
||||
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert not result.result.value.startswith('"')
|
||||
|
||||
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
truncator = VariableTruncator()
|
||||
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_segment(IntegerSegment(value=1), 10)
|
||||
|
||||
def test_calculate_json_size_should_unwrap_segment_values(self) -> None:
|
||||
segment = StringSegment(value="abc")
|
||||
|
||||
size = VariableTruncator.calculate_json_size(segment)
|
||||
|
||||
assert size == VariableTruncator.calculate_json_size("abc")
|
||||
|
||||
def test_calculate_json_size_should_handle_updated_variable_instances(self) -> None:
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
size = VariableTruncator.calculate_json_size(updated)
|
||||
|
||||
assert size > 0
|
||||
|
||||
def test_maybe_qa_structure_should_validate_shape(self) -> None:
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
|
||||
assert VariableTruncator._maybe_qa_structure({}) is False
|
||||
|
||||
def test_maybe_parent_child_structure_should_validate_shape(self) -> None:
|
||||
assert (
|
||||
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
|
||||
)
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
|
||||
assert (
|
||||
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"})
|
||||
is False
|
||||
)
|
||||
|
||||
def test_truncate_object_should_truncate_segment_values_inside_object(self) -> None:
|
||||
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
|
||||
mapping = {"s": StringSegment(value="long-content")}
|
||||
|
||||
result = truncator._truncate_object(mapping, 20)
|
||||
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value["s"], StringSegment)
|
||||
|
||||
def test_truncate_json_primitives_should_handle_updated_variable_input(self) -> None:
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
result = truncator._truncate_json_primitives(updated, 100)
|
||||
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type(self) -> None:
|
||||
truncator = VariableTruncator()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
|
||||
|
||||
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_segment = ObjectSegment(value={"k": "v"})
|
||||
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
|
||||
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
@ -559,771 +559,3 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1", is_debug=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
# Act / Assert
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
# Act
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
# Assert
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
# Act
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
# Assert
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
|
||||
# Arrange
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
|
||||
# Arrange
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
# Assert
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
|
||||
# Arrange
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
# Assert
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
|
||||
# Arrange
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
# Assert
|
||||
assert True
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
|
||||
# Arrange
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
# Assert
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
|
||||
# Arrange
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
|
||||
# Act
|
||||
result = WebhookService._extract_content_type(headers)
|
||||
|
||||
# Assert
|
||||
assert result == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
|
||||
|
||||
# Act
|
||||
result = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Assert
|
||||
assert result["webhook_data"] == webhook_data
|
||||
assert result["webhook_headers"] == {"h": "v"}
|
||||
assert result["webhook_query_params"] == {"q": 1}
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
# Act
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Assert
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
quota_type = SimpleNamespace(
|
||||
TRIGGER=SimpleNamespace(
|
||||
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.commit_count = 0
|
||||
self.existing_records = [SimpleNamespace(node_id="node-stale")]
|
||||
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: self.existing_records)
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def flush(self) -> None:
|
||||
for idx, obj in enumerate(self.added, start=1):
|
||||
if obj.id is None:
|
||||
obj.id = f"rec-{idx}"
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
|
||||
def delete(self, obj: Any) -> None:
|
||||
self.deleted.append(obj)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.return_value = None
|
||||
|
||||
fake_session = _Session()
|
||||
|
||||
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
|
||||
# Arrange
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
# Act
|
||||
body, status = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
# Assert
|
||||
assert status == 200
|
||||
assert "message" in body
|
||||
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
webhook_id = WebhookService.generate_webhook_id()
|
||||
|
||||
# Assert
|
||||
assert isinstance(webhook_id, str)
|
||||
assert len(webhook_id) == 24
|
||||
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
|
||||
|
||||
# Assert
|
||||
assert result == 123
|
||||
|
||||
671
api/tests/unit_tests/services/test_webhook_service_additional.py
Normal file
671
api/tests/unit_tests/services/test_webhook_service_additional.py
Normal file
@ -0,0 +1,671 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
class TestWebhookServiceLookup:
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1",
|
||||
is_debug=True,
|
||||
)
|
||||
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
class TestWebhookServiceExtractionFallbacks:
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(self, flask_app: Flask) -> None:
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")
|
||||
)
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
class TestWebhookServiceValidationAndConversion:
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
self,
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
assert result == expected
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type(self) -> None:
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors(self) -> None:
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing(self) -> None:
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters(self) -> None:
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing(self) -> None:
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data(self) -> None:
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names(self) -> None:
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing(self) -> None:
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error(self) -> None:
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key(self) -> None:
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
assert WebhookService._extract_content_type(headers) == "application/json"
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys(self) -> None:
|
||||
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
|
||||
|
||||
result = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
assert result["webhook_data"] == webhook_data
|
||||
assert result["webhook_headers"] == {"h": "v"}
|
||||
assert result["webhook_query_params"] == {"q": 1}
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
class TestWebhookServiceExecutionAndSync:
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=end_user),
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
quota_type = SimpleNamespace(
|
||||
TRIGGER=SimpleNamespace(
|
||||
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(side_effect=RuntimeError("boom")),
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.commit_count = 0
|
||||
self.existing_records = [SimpleNamespace(node_id="node-stale")]
|
||||
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: self.existing_records)
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def flush(self) -> None:
|
||||
for idx, obj in enumerate(self.added, start=1):
|
||||
if obj.id is None:
|
||||
obj.id = f"rec-{idx}"
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
|
||||
def delete(self, obj: Any) -> None:
|
||||
self.deleted.append(obj)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.return_value = None
|
||||
|
||||
fake_session = _Session()
|
||||
|
||||
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
class TestWebhookServiceUtilities:
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
body, status = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status == 200
|
||||
assert "message" in body
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier(self) -> None:
|
||||
webhook_id = WebhookService.generate_webhook_id()
|
||||
|
||||
assert isinstance(webhook_id, str)
|
||||
assert len(webhook_id) == 24
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input(self) -> None:
|
||||
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
|
||||
assert result == 123
|
||||
262
api/tests/unit_tests/services/test_workflow_run_service.py
Normal file
262
api/tests/unit_tests/services/test_workflow_run_service.py
Normal file
@ -0,0 +1,262 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
|
||||
from services import workflow_run_service as service_module
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
|
||||
node_repo = MagicMock()
|
||||
workflow_run_repo = MagicMock()
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
return node_repo, workflow_run_repo, factory
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
class TestWorkflowRunServiceInitialization:
|
||||
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
|
||||
|
||||
service = WorkflowRunService()
|
||||
|
||||
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
def test___init___should_create_sessionmaker_when_engine_is_provided(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
class FakeEngine:
|
||||
pass
|
||||
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "Engine", FakeEngine)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
engine = cast(Engine, FakeEngine())
|
||||
|
||||
service = WorkflowRunService(session_factory=engine)
|
||||
|
||||
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
node_repo, workflow_run_repo, factory = repository_factory_mocks
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
|
||||
service = WorkflowRunService(session_factory=session_factory)
|
||||
|
||||
assert service._session_factory is session_factory
|
||||
assert service._node_execution_service_repo is node_repo
|
||||
assert service._workflow_run_repo is workflow_run_repo
|
||||
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
|
||||
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
|
||||
class TestWorkflowRunServiceQueries:
|
||||
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="pagination")
|
||||
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
|
||||
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
|
||||
|
||||
result = service.get_paginate_workflow_runs(
|
||||
app_model=app_model,
|
||||
args=args,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
limit=7,
|
||||
last_id="last-1",
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
run_with_message = SimpleNamespace(
|
||||
id="run-1",
|
||||
status="running",
|
||||
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
|
||||
)
|
||||
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
|
||||
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
|
||||
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
|
||||
|
||||
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
|
||||
|
||||
assert result is pagination
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].message_id == "msg-1"
|
||||
assert result.data[0].conversation_id == "conv-1"
|
||||
assert result.data[0].status == "running"
|
||||
assert not hasattr(result.data[1], "message_id")
|
||||
assert result.data[1].id == "run-2"
|
||||
|
||||
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="workflow_run")
|
||||
workflow_run_repo.get_workflow_run_by_id.return_value = expected
|
||||
|
||||
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
|
||||
|
||||
assert result is expected
|
||||
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
def test_get_workflow_runs_count_should_forward_optional_filters(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = {"total": 3, "succeeded": 2}
|
||||
workflow_run_repo.get_workflow_runs_count.return_value = expected
|
||||
|
||||
result = service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
)
|
||||
|
||||
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-1")
|
||||
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
|
||||
class FakeEndUser:
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
|
||||
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
|
||||
app_model = _app_model(id="app-1")
|
||||
expected = [SimpleNamespace(id="exec-1")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-end-user",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-account")
|
||||
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-account",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
|
||||
self,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="tenant_id cannot be None"):
|
||||
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
@ -176,300 +176,3 @@ class TestWorkflowRunService:
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
|
||||
|
||||
# === Merged from test_workflow_run_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
|
||||
from services import workflow_run_service as service_module
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
|
||||
# Arrange
|
||||
node_repo = MagicMock()
|
||||
workflow_run_repo = MagicMock()
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
|
||||
# Assert
|
||||
return node_repo, workflow_run_repo, factory
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService()
|
||||
|
||||
# Assert
|
||||
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_when_engine_is_provided(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
class FakeEngine:
|
||||
pass
|
||||
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "Engine", FakeEngine)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
engine = cast(Engine, FakeEngine())
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService(session_factory=engine)
|
||||
|
||||
# Assert
|
||||
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
|
||||
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, workflow_run_repo, factory = repository_factory_mocks
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService(session_factory=session_factory)
|
||||
|
||||
# Assert
|
||||
assert service._session_factory is session_factory
|
||||
assert service._node_execution_service_repo is node_repo
|
||||
assert service._workflow_run_repo is workflow_run_repo
|
||||
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
|
||||
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="pagination")
|
||||
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
|
||||
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_runs(
|
||||
app_model=app_model,
|
||||
args=args,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
limit=7,
|
||||
last_id="last-1",
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
run_with_message = SimpleNamespace(
|
||||
id="run-1",
|
||||
status="running",
|
||||
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
|
||||
)
|
||||
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
|
||||
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
|
||||
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
|
||||
|
||||
# Assert
|
||||
assert result is pagination
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].message_id == "msg-1"
|
||||
assert result.data[0].conversation_id == "conv-1"
|
||||
assert result.data[0].status == "running"
|
||||
assert not hasattr(result.data[1], "message_id")
|
||||
assert result.data[1].id == "run-2"
|
||||
|
||||
|
||||
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="workflow_run")
|
||||
workflow_run_repo.get_workflow_run_by_id.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_runs_count_should_forward_optional_filters(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = {"total": 3, "succeeded": 2}
|
||||
workflow_run_repo.get_workflow_runs_count.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-1")
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
|
||||
class FakeEndUser:
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
|
||||
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
|
||||
app_model = _app_model(id="app-1")
|
||||
expected = [SimpleNamespace(id="exec-1")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-end-user",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-account")
|
||||
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-account",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id=None)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="tenant_id cannot be None"):
|
||||
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
@ -3,7 +3,6 @@ import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from itertools import cycle
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
@ -223,577 +222,3 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
|
||||
|
||||
# === Merged from test_workflow_event_snapshot_service_additional.py ===
|
||||
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services import workflow_event_snapshot_service as service_module
|
||||
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
|
||||
|
||||
|
||||
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"query": "hello"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=1.2,
|
||||
total_tokens=5,
|
||||
total_steps=2,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.outputs = {"answer": "ok"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionMaker:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __call__(self) -> _SessionContext:
|
||||
return _SessionContext(self._session)
|
||||
|
||||
|
||||
class _SubscriptionContext:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._subscription
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _Topic:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def subscribe(self) -> _SubscriptionContext:
|
||||
return _SubscriptionContext(self._subscription)
|
||||
|
||||
|
||||
class _StaticSubscription:
|
||||
def receive(self, timeout: int = 1) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _PauseEntity(WorkflowPauseEntity):
|
||||
state: bytes
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return "pause-1"
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return "run-1"
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
return self.state
|
||||
|
||||
def get_pause_reasons(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def test_get_message_context_should_return_none_when_no_message() -> None:
|
||||
# Arrange
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=None))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
# Act
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
|
||||
# Arrange
|
||||
message = SimpleNamespace(
|
||||
id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
created_at=None,
|
||||
answer="answer",
|
||||
)
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=message))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
# Act
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.created_at == 0
|
||||
assert result.message_id == "msg-1"
|
||||
assert result.conversation_id == "conv-1"
|
||||
assert result.answer == "answer"
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(None)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
|
||||
# Arrange
|
||||
pause_entity = _PauseEntity(state=b"not-a-valid-state")
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(pause_entity)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
|
||||
# Arrange
|
||||
context = _build_resumption_context_additional(task_id="task-ctx")
|
||||
pause_entity = _PauseEntity(state=context.dumps().encode())
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(pause_entity)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.get_generate_entity().task_id == "task-ctx"
|
||||
|
||||
|
||||
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._resolve_task_id(
|
||||
resumption_context=None,
|
||||
buffer_state=None,
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("payload", "expected"),
|
||||
[
|
||||
(b'{"event":"node_started"}', {"event": "node_started"}),
|
||||
(b"invalid-json", None),
|
||||
(b"[]", None),
|
||||
],
|
||||
)
|
||||
def test_parse_event_message_should_parse_only_json_object(
|
||||
payload: bytes,
|
||||
expected: dict[str, Any] | None,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._parse_event_message(payload)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
|
||||
# Arrange
|
||||
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
|
||||
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
|
||||
|
||||
# Act
|
||||
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
|
||||
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
|
||||
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
|
||||
|
||||
# Assert
|
||||
assert is_finished is True
|
||||
assert paused_without_flag is False
|
||||
assert paused_with_flag is True
|
||||
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
|
||||
|
||||
|
||||
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
|
||||
# Arrange
|
||||
payload: dict[str, Any] = {"event": "workflow_started"}
|
||||
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
|
||||
|
||||
# Act
|
||||
service_module._apply_message_context(payload, context)
|
||||
|
||||
# Assert
|
||||
assert payload["conversation_id"] == "conv-1"
|
||||
assert payload["message_id"] == "msg-1"
|
||||
assert payload["created_at"] == 1700000000
|
||||
|
||||
|
||||
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
|
||||
# Arrange
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-1"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
event = buffer_state.queue.get(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert buffer_state.task_id_hint == "task-1"
|
||||
assert event["event"] == "node_started"
|
||||
|
||||
|
||||
def test_start_buffering_should_drop_old_event_when_queue_is_full(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
class QueueWithSingleFull:
|
||||
def __init__(self) -> None:
|
||||
self._first_put = True
|
||||
self.items: list[dict[str, Any]] = [{"event": "old"}]
|
||||
|
||||
def put_nowait(self, item: dict[str, Any]) -> None:
|
||||
if self._first_put:
|
||||
self._first_put = False
|
||||
raise queue.Full
|
||||
self.items.append(item)
|
||||
|
||||
def get_nowait(self) -> dict[str, Any]:
|
||||
if not self.items:
|
||||
raise queue.Empty
|
||||
return self.items.pop(0)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.items) == 0
|
||||
|
||||
fake_queue = QueueWithSingleFull()
|
||||
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
|
||||
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-2"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert fake_queue.items[-1]["task_id"] == "task-2"
|
||||
|
||||
|
||||
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
|
||||
# Arrange
|
||||
class Subscription:
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
raise RuntimeError("subscription failure")
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert finished is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_get_message_context",
|
||||
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_build_snapshot_events",
|
||||
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
|
||||
)
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
finished_event = cast(Mapping[str, Any], events[1])
|
||||
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
|
||||
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
|
||||
assert called_kwargs["workflow_run_id"] == "run-1"
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
|
||||
class AlwaysEmptyQueue:
|
||||
def empty(self) -> bool:
|
||||
return False
|
||||
|
||||
def get(self, timeout: int = 1) -> None:
|
||||
raise queue.Empty
|
||||
|
||||
buffer_state = BufferState(
|
||||
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
time_values = cycle([0.0, 6.0, 21.0, 26.0])
|
||||
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
idle_timeout=20.0,
|
||||
ping_interval=5.0,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
buffer_state.done_event.set()
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
assert snapshot_builder.call_args.kwargs["pause_entity"] is None
|
||||
|
||||
@ -0,0 +1,505 @@
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from itertools import cycle
|
||||
from threading import Event
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services import workflow_event_snapshot_service as service_module
|
||||
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"query": "hello"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=1.2,
|
||||
total_tokens=5,
|
||||
total_steps=2,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.outputs = {"answer": "ok"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionMaker:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __call__(self) -> _SessionContext:
|
||||
return _SessionContext(self._session)
|
||||
|
||||
|
||||
class _SubscriptionContext:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._subscription
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _Topic:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def subscribe(self) -> _SubscriptionContext:
|
||||
return _SubscriptionContext(self._subscription)
|
||||
|
||||
|
||||
class _StaticSubscription:
|
||||
def receive(self, timeout: int = 1) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _PauseEntity(WorkflowPauseEntity):
|
||||
state: bytes
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return "pause-1"
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return "run-1"
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
return self.state
|
||||
|
||||
def get_pause_reasons(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
class TestWorkflowEventSnapshotHelpers:
|
||||
def test_get_message_context_should_return_none_when_no_message(self) -> None:
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=None))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp(self) -> None:
|
||||
message = SimpleNamespace(
|
||||
id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
created_at=None,
|
||||
answer="answer",
|
||||
)
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=message))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
assert result is not None
|
||||
assert result.created_at == 0
|
||||
assert result.message_id == "msg-1"
|
||||
assert result.conversation_id == "conv-1"
|
||||
assert result.answer == "answer"
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_missing(self) -> None:
|
||||
assert service_module._load_resumption_context(None) is None
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid(self) -> None:
|
||||
pause_entity = _PauseEntity(state=b"not-a-valid-state")
|
||||
assert service_module._load_resumption_context(pause_entity) is None
|
||||
|
||||
def test_load_resumption_context_should_parse_valid_state_into_context(self) -> None:
|
||||
context = _build_resumption_context(task_id="task-ctx")
|
||||
pause_entity = _PauseEntity(state=context.dumps().encode())
|
||||
|
||||
result = service_module._load_resumption_context(pause_entity)
|
||||
|
||||
assert result is not None
|
||||
assert result.get_generate_entity().task_id == "task-ctx"
|
||||
|
||||
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing(self) -> None:
|
||||
result = service_module._resolve_task_id(
|
||||
resumption_context=None,
|
||||
buffer_state=None,
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
assert result == "run-1"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("payload", "expected"),
|
||||
[
|
||||
(b'{"event":"node_started"}', {"event": "node_started"}),
|
||||
(b"invalid-json", None),
|
||||
(b"[]", None),
|
||||
],
|
||||
)
|
||||
def test_parse_event_message_should_parse_only_json_object(
|
||||
self,
|
||||
payload: bytes,
|
||||
expected: dict[str, Any] | None,
|
||||
) -> None:
|
||||
result = service_module._parse_event_message(payload)
|
||||
assert result == expected
|
||||
|
||||
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events(self) -> None:
|
||||
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
|
||||
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
|
||||
|
||||
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
|
||||
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
|
||||
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
|
||||
|
||||
assert is_finished is True
|
||||
assert paused_without_flag is False
|
||||
assert paused_with_flag is True
|
||||
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
|
||||
|
||||
def test_apply_message_context_should_update_payload_when_context_exists(self) -> None:
|
||||
payload: dict[str, Any] = {"event": "workflow_started"}
|
||||
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
|
||||
|
||||
service_module._apply_message_context(payload, context)
|
||||
|
||||
assert payload["conversation_id"] == "conv-1"
|
||||
assert payload["message_id"] == "msg-1"
|
||||
assert payload["created_at"] == 1700000000
|
||||
|
||||
def test_start_buffering_should_capture_task_id_and_enqueue_event(self) -> None:
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-1"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
event = buffer_state.queue.get(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert buffer_state.task_id_hint == "task-1"
|
||||
assert event["event"] == "node_started"
|
||||
|
||||
def test_start_buffering_should_drop_old_event_when_queue_is_full(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class QueueWithSingleFull:
|
||||
def __init__(self) -> None:
|
||||
self._first_put = True
|
||||
self.items: list[dict[str, Any]] = [{"event": "old"}]
|
||||
|
||||
def put_nowait(self, item: dict[str, Any]) -> None:
|
||||
if self._first_put:
|
||||
self._first_put = False
|
||||
raise queue.Full
|
||||
self.items.append(item)
|
||||
|
||||
def get_nowait(self) -> dict[str, Any]:
|
||||
if not self.items:
|
||||
raise queue.Empty
|
||||
return self.items.pop(0)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.items) == 0
|
||||
|
||||
fake_queue = QueueWithSingleFull()
|
||||
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
|
||||
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-2"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert fake_queue.items[-1]["task_id"] == "task-2"
|
||||
|
||||
def test_start_buffering_should_set_done_event_when_subscription_raises(self) -> None:
|
||||
class Subscription:
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
raise RuntimeError("subscription failure")
|
||||
|
||||
subscription = Subscription()
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
|
||||
assert buffer_state.done_event.wait(timeout=1) is True
|
||||
|
||||
|
||||
class TestBuildWorkflowEventStream:
|
||||
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_get_message_context",
|
||||
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_build_snapshot_events",
|
||||
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
|
||||
)
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
finished_event = cast(Mapping[str, Any], events[1])
|
||||
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
|
||||
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
|
||||
assert called_kwargs["workflow_run_id"] == "run-1"
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
|
||||
class AlwaysEmptyQueue:
|
||||
def empty(self) -> bool:
|
||||
return False
|
||||
|
||||
def get(self, timeout: int = 1) -> None:
|
||||
raise queue.Empty
|
||||
|
||||
buffer_state = BufferState(
|
||||
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
time_values = cycle([0.0, 6.0, 21.0, 26.0])
|
||||
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
idle_timeout=20.0,
|
||||
ping_interval=5.0,
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
buffer_state.done_event.set()
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.PAUSED)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
assert snapshot_builder.call_args.kwargs["pause_entity"] is None
|
||||
Loading…
Reference in New Issue
Block a user