diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 2adaf14a35..a881fba877 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): """ """ + def __init__(self) -> None: + super().__init__() + self._paused = False + def on_graph_start(self): - pass + self._paused = False def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. """ if isinstance(event, GraphRunPausedEvent): - pass + self._paused = True def on_graph_end(self, error: Exception | None): """ """ - pass + self._paused = False + + def is_paused(self) -> bool: + return self._paused diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 99b64b3ab5..d95a378575 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._handle_graph_run_paused(event) return - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - if isinstance(event, NodeRunRetryEvent): self._handle_node_retry(event) return + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + if isinstance(event, NodeRunSucceededEvent): self._handle_node_succeeded(event) return diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada31..818ae46625 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -68,7 +68,7 @@ def init_tool_node(config: dict): return node -def test_tool_variable_invoke(): +def test_tool_variable_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", @@ -103,7 +103,7 @@ def test_tool_variable_invoke(): assert item.node_run_result.outputs.get("text") is not None -def test_tool_mixed_invoke(): +def test_tool_mixed_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", diff --git a/api/tests/unit_tests/core/app/app_config/__init__.py b/api/tests/unit_tests/core/app/app_config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py new file mode 100644 index 0000000000..1c5b6ed944 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py @@ -0,0 +1,227 @@ +from unittest.mock import MagicMock + +import pytest + +# Module under test +from core.app.app_config.common import parameters_mapping + + +class TestGetParametersFromFeatureDict: + """Test suite for get_parameters_from_feature_dict""" + + @pytest.fixture + def mock_config(self, monkeypatch): + """Mock dify_config values""" + mock = MagicMock() + mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1 + mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2 + mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3 + mock.UPLOAD_FILE_SIZE_LIMIT = 4 + mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5 + + monkeypatch.setattr(parameters_mapping, "dify_config", mock) + return mock + + @pytest.fixture + def mock_default_file_limits(self, monkeypatch): + """Mock DEFAULT_FILE_NUMBER_LIMITS constant""" + monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99) + return 99 + + @pytest.fixture + def minimal_inputs(self): + return {}, [] + + @pytest.mark.parametrize( + ("feature_key", "expected_default"), + [ + ("suggested_questions", []), + ("suggested_questions_after_answer", {"enabled": False}), + ("speech_to_text", {"enabled": False}), + ("text_to_speech", {"enabled": False}), + ("retriever_resource", {"enabled": False}), + ("annotation_reply", {"enabled": False}), + ("more_like_this", {"enabled": False}), + ( + "sensitive_word_avoidance", + {"enabled": False, "type": "", "configs": []}, + ), + ], + ) + def test_defaults_when_key_missing( + self, + feature_key, + expected_default, + mock_config, + mock_default_file_limits, + ): + # Arrange + features = {} + user_input = [] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + assert result[feature_key] == expected_default + + def test_opening_statement_present(self, mock_config, mock_default_file_limits): + # Arrange + features = {"opening_statement": "Hello"} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] == "Hello" + + def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] is None + + def test_all_features_provided(self, mock_config, mock_default_file_limits): + # Arrange + features = { + "opening_statement": "Hi", + "suggested_questions": ["Q1"], + "suggested_questions_after_answer": {"enabled": True}, + "speech_to_text": {"enabled": True}, + "text_to_speech": {"enabled": True}, + "retriever_resource": {"enabled": True}, + "annotation_reply": {"enabled": True}, + "more_like_this": {"enabled": True}, + "sensitive_word_avoidance": { + "enabled": True, + "type": "strict", + "configs": ["a"], + }, + "file_upload": { + "image": { + "enabled": True, + "number_limits": 10, + "detail": "low", + "transfer_methods": ["local_file"], + } + }, + } + user_input = [{"name": "field1"}] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + for key in features: + assert result[key] == features[key] + assert result["user_input_form"] == user_input + + def test_file_upload_default_structure(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + file_upload = result["file_upload"] + assert file_upload["image"]["enabled"] is False + assert file_upload["image"]["number_limits"] == 99 + assert file_upload["image"]["detail"] == "high" + assert "remote_url" in file_upload["image"]["transfer_methods"] + assert "local_file" in file_upload["image"]["transfer_methods"] + + def test_system_parameters_from_config(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + system_params = result["system_parameters"] + assert system_params["image_file_size_limit"] == 1 + assert system_params["video_file_size_limit"] == 2 + assert system_params["audio_file_size_limit"] == 3 + assert system_params["file_size_limit"] == 4 + assert system_params["workflow_file_upload_limit"] == 5 + + @pytest.mark.parametrize( + ("features_dict", "user_input_form"), + [ + (None, []), + ([], []), + ("invalid", []), + ], + ) + def test_invalid_features_dict_type_raises(self, features_dict, user_input_form): + # Act & Assert + with pytest.raises(AttributeError): + parameters_mapping.get_parameters_from_feature_dict( + features_dict=features_dict, + user_input_form=user_input_form, + ) + + @pytest.mark.parametrize( + "user_input_form", + [None, "invalid", 123], + ) + def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input_form, + ) + + # Assert + assert result["user_input_form"] == user_input_form + + def test_empty_user_input_form(self, mock_config, mock_default_file_limits): + features = {} + user_input = [] + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + assert result["user_input_form"] == [] + + def test_feature_values_none(self, mock_config, mock_default_file_limits): + features = { + "suggested_questions": None, + "speech_to_text": None, + } + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + assert result["suggested_questions"] is None + assert result["speech_to_text"] is None diff --git a/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py new file mode 100644 index 0000000000..013ed0cbc4 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.common.sensitive_word_avoidance.manager import ( + SensitiveWordAvoidanceConfigManager, +) + + +class TestSensitiveWordAvoidanceConfigManagerConvert: + """Tests for convert classmethod""" + + @pytest.mark.parametrize( + "config", + [ + {}, + {"sensitive_word_avoidance": None}, + {"sensitive_word_avoidance": {}}, + {"sensitive_word_avoidance": {"enabled": False}}, + ], + ) + def test_convert_returns_none_when_disabled_or_missing(self, config): + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result is None + + def test_convert_returns_entity_when_enabled(self, mocker): + # Arrange + mock_entity = MagicMock() + mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"key": "value"}, + } + } + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result == mock_entity + + def test_convert_enabled_without_type_or_config(self, mocker): + # Arrange + mock_entity = MagicMock() + patched = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = {"sensitive_word_avoidance": {"enabled": True}} + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + patched.assert_called_once_with(type=None, config={}) + assert result == mock_entity + + +class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults: + """Tests for validate_and_set_defaults classmethod""" + + @pytest.fixture + def base_config(self): + return {} + + def test_validate_sets_default_when_missing(self, base_config): + # Act + config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=base_config.copy() + ) + + # Assert + assert config["sensitive_word_avoidance"]["enabled"] is False + assert fields == ["sensitive_word_avoidance"] + + def test_validate_raises_when_not_dict(self): + config = {"sensitive_word_avoidance": "invalid"} + + with pytest.raises(ValueError, match="must be of dict type"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + @pytest.mark.parametrize( + "config", + [ + {"sensitive_word_avoidance": {"enabled": False}}, + {"sensitive_word_avoidance": {"enabled": None}}, + {"sensitive_word_avoidance": {}}, + ], + ) + def test_validate_disables_when_enabled_false_or_missing(self, config): + # Act + result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + assert result_config["sensitive_word_avoidance"]["enabled"] is False + + def test_validate_raises_when_enabled_true_without_type(self): + config = {"sensitive_word_avoidance": {"enabled": True}} + + with pytest.raises(ValueError, match="type is required"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_type_not_string(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": 123, + } + } + + with pytest.raises(ValueError, match="must be a string"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_config_not_dict(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": "invalid", + } + } + + with pytest.raises(ValueError, match="must be a dict"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_calls_moderation_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"}) + assert result_config["sensitive_word_avoidance"]["enabled"] is True + assert fields == ["sensitive_word_avoidance"] + + def test_validate_sets_empty_dict_when_config_none(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": None, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={}) + + def test_validate_only_structure_validate_skips_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config, only_structure_validate=True + ) + + # Assert + mock_validate.assert_not_called() diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py new file mode 100644 index 0000000000..992b580376 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager + + +class TestAgentConfigManagerConvert: + @pytest.fixture + def base_config(self): + return { + "agent_mode": { + "enabled": True, + "strategy": "cot", + "tools": [], + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "completion", + }, + } + + def test_convert_returns_none_when_agent_mode_missing(self): + config = {"model": {"provider": "openai", "name": "gpt-4"}} + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}]) + def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config): + config = base_config.copy() + config["agent_mode"] = agent_mode_value + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize( + ("strategy_input", "expected_enum"), + [ + ("function_call", "FUNCTION_CALLING"), + ("cot", "CHAIN_OF_THOUGHT"), + ("react", "CHAIN_OF_THOUGHT"), + ], + ) + def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": strategy_input, + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.strategy.name == expected_enum + + def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "openai" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "FUNCTION_CALLING" + + def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "anthropic" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "CHAIN_OF_THOUGHT" + + def test_convert_skips_disabled_tools(self, mocker, base_config): + # Patch AgentEntity to bypass pydantic validation + mock_agent_entity = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity", + return_value=MagicMock(), + ) + + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value={ + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "tool_parameters": {}, + "credential_id": None, + }, + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + { + "provider_type": "type1", + "provider_id": "id1", + "tool_name": "tool1", + "enabled": False, + }, + { + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "enabled": True, + "extra_key": "x", + }, + ], + } + + AgentConfigManager.convert(config) + + mock_validate.assert_called_once() + mock_agent_entity.assert_called_once() + + def test_convert_tool_requires_minimum_keys(self, mocker, base_config): + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value=MagicMock(), + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + {"a": 1, "b": 2}, # insufficient keys + ], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.tools == [] + mock_validate.assert_not_called() + + def test_convert_completion_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "completion" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_chat_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "chat" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_react_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "react_router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_max_iteration_default(self, base_config): + config = base_config.copy() + config["agent_mode"].pop("max_iteration", None) + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 10 + + def test_convert_custom_max_iteration(self, base_config): + config = base_config.copy() + config["agent_mode"]["max_iteration"] = 25 + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 25 + + def test_convert_missing_model_raises_key_error(self, base_config): + config = base_config.copy() + del config["model"] + + with pytest.raises(KeyError): + AgentConfigManager.convert(config) + + @pytest.mark.parametrize( + ("invalid_config", "should_raise"), + [ + (None, True), + (123, True), + ("", False), + ([], False), + ], + ) + def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise): + if should_raise: + with pytest.raises(TypeError): + AgentConfigManager.convert(invalid_config) # type: ignore + else: + result = AgentConfigManager.convert(invalid_config) # type: ignore + assert result is None diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py new file mode 100644 index 0000000000..a688e2a5c5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -0,0 +1,319 @@ +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def valid_uuid(): + return str(uuid.uuid4()) + + +@pytest.fixture +def base_config(valid_uuid): + return { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + + +@pytest.fixture +def mock_dataset_service(mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + +# ============================== +# convert tests +# ============================== + + +class TestDatasetConfigManagerConvert: + def test_convert_returns_none_when_no_datasets(self): + config = {"dataset_configs": {"datasets": {"datasets": []}}} + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_single_retrieval(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result is not None + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable == "query" + + def test_convert_single_with_metadata_configs(self, valid_uuid, mocker): + mock_retrieve_config = MagicMock() + mock_entity = MagicMock() + mock_entity.dataset_ids = [valid_uuid] + mock_entity.retrieve_config = mock_retrieve_config + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig", + return_value={"mock": "model"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition", + return_value={"mock": "condition"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity", + return_value=mock_retrieve_config, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity", + return_value=mock_entity, + ) + + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "metadata_filtering_mode": "manual", + "metadata_model_config": {"any": "value"}, + "metadata_filtering_conditions": {"any": "value"}, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config is mock_retrieve_config + + def test_convert_multiple_defaults(self, valid_uuid): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 4 + assert result.retrieve_config.score_threshold is None + assert result.retrieve_config.reranking_enabled is True + + def test_convert_agent_mode_disabled_tool(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}], + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_dataset_configs_none(self): + config = {"dataset_configs": None} + with pytest.raises(TypeError): + DatasetConfigManager.convert(config) + + def test_convert_agent_mode_old_style_old_format(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable is None + + def test_convert_multiple_with_score_threshold(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "multiple", + "top_k": 5, + "score_threshold": 0.8, + "score_threshold_enabled": True, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 5 + assert result.retrieve_config.score_threshold == 0.8 + + @pytest.mark.parametrize( + "dataset_entry", + [ + {}, + {"invalid": {}}, + {"dataset": {"id": None, "enabled": True}}, + {"dataset": {"id": "", "enabled": False}}, + ], + ) + def test_convert_ignores_invalid_dataset_entries(self, dataset_entry): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": {"strategy": "router", "datasets": [dataset_entry]}, + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_agent_mode_old_style(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + + +# ============================== +# validate_and_set_defaults tests +# ============================== + + +class TestValidateAndSetDefaults: + def test_validate_sets_defaults(self): + config = {} + updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + assert "dataset_configs" in updated + assert updated["dataset_configs"]["retrieval_model"] == "single" + assert isinstance(fields, list) + + def test_validate_raises_when_dataset_configs_not_dict(self): + config = {"dataset_configs": "invalid"} + with pytest.raises(AttributeError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + + def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid): + config = { + "dataset_configs": { + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + } + with pytest.raises(ValueError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfig: + def test_extract_sets_defaults(self): + config = {} + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + assert "agent_mode" in result + assert result["agent_mode"]["enabled"] is False + assert result["agent_mode"]["tools"] == [] + + def test_extract_invalid_agent_mode_type(self): + config = {"agent_mode": "invalid"} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_enabled_type(self): + config = {"agent_mode": {"enabled": "yes"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_tools_type(self): + config = {"agent_mode": {"enabled": True, "tools": "invalid"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_uuid(self, mocker): + invalid_uuid = "not-a-uuid" + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_dataset_not_exists(self, valid_uuid, mocker): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + +# ============================== +# is_dataset_exists tests +# ============================== + + +class TestIsDatasetExists: + def test_dataset_exists_true(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "other" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py new file mode 100644 index 0000000000..aed1651511 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.entities.model_entities import ModelStatus +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey + + +class TestModelConfigConverter: + @pytest.fixture(autouse=True) + def patch_response_entity(self, mocker): + """ + Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation + and return a simple namespace object instead. + """ + + def _factory(**kwargs): + return SimpleNamespace(**kwargs) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity", + side_effect=_factory, + ) + + @pytest.fixture + def mock_app_config(self): + app_config = MagicMock() + app_config.tenant_id = "tenant_1" + + model_config = MagicMock() + model_config.provider = "openai" + model_config.model = "gpt-4" + model_config.parameters = {"temperature": 0.5} + model_config.mode = None + + app_config.model = model_config + return app_config + + @pytest.fixture + def mock_provider_bundle(self): + bundle = MagicMock() + + # configuration + configuration = MagicMock() + configuration.provider.provider = "openai" + configuration.get_current_credentials.return_value = {"api_key": "key"} + + provider_model = MagicMock() + provider_model.status = ModelStatus.ACTIVE + configuration.get_provider_model.return_value = provider_model + + bundle.configuration = configuration + + # model type instance + model_type_instance = MagicMock() + model_schema = MagicMock() + model_schema.model_properties = {} + model_type_instance.get_model_schema.return_value = model_schema + bundle.model_type_instance = model_type_instance + + return bundle + + @pytest.fixture + def patch_provider_manager(self, mocker, mock_provider_bundle): + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + return mock_manager + + # ============================= + # Positive Scenarios + # ============================= + + def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager): + result = ModelConfigConverter.convert(mock_app_config) + + assert result.provider == "openai" + assert result.model == "gpt-4" + assert result.mode == LLMMode.CHAT + assert result.parameters == {"temperature": 0.5} + assert result.stop == [] + + def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager): + mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]} + + result = ModelConfigConverter.convert(mock_app_config) + + assert result.parameters == {"temperature": 0.7} + assert result.stop == ["\n"] + + def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker): + mock_app_config.model.mode = None + + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: LLMMode.COMPLETION.value + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.COMPLETION + + def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: "invalid" + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.CHAT + + # ============================= + # Credential Errors + # ============================= + + def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_current_credentials.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ProviderTokenNotInitError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Provider Model Errors + # ============================= + + def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_provider_model.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError), + (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError), + (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError), + ], + ) + def test_convert_provider_model_status_errors( + self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception + ): + mock_provider = MagicMock() + mock_provider.status = status + mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(expected_exception): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Schema Errors + # ============================= + + def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Edge Cases + # ============================= + + @pytest.mark.parametrize( + "parameters", + [ + {}, + {"stop": []}, + {"stop": ["END"], "max_tokens": 100}, + ], + ) + def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters): + mock_app_config.model.parameters = parameters.copy() + + result = ModelConfigConverter.convert(mock_app_config) + + if "stop" in parameters: + assert result.stop == parameters.get("stop") + expected_params = parameters.copy() + expected_params.pop("stop", None) + assert result.parameters == expected_params + else: + assert result.stop == [] + assert result.parameters == parameters diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py new file mode 100644 index 0000000000..e2ba276d8e --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -0,0 +1,230 @@ +from unittest.mock import MagicMock + +import pytest + +# Target +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def valid_completion_params(): + return {"temperature": 0.7, "stop": ["\n"]} + + +@pytest.fixture +def valid_model_list(): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {"mode": "chat"} + return [model] + + +@pytest.fixture +def provider_entities(): + provider = MagicMock() + provider.provider = "openai/gpt" + return [provider] + + +@pytest.fixture +def valid_config(): + return { + "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}} + } + + +# ----------------------------- +# Test Class +# ----------------------------- + + +class TestModelConfigManager: + # ========================================================== + # convert + # ========================================================== + + def test_convert_success(self, valid_config): + result = ModelConfigManager.convert(valid_config) + + assert result.provider == "openai/gpt" + assert result.model == "gpt-4" + assert result.parameters == {"temperature": 0.5} + assert result.stop == ["END"] + + def test_convert_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.convert({}) + + def test_convert_without_stop(self): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}} + result = ModelConfigManager.convert(config) + assert result.stop == [] + assert result.parameters == {"temperature": 0.9} + + # ========================================================== + # validate_model_completion_params + # ========================================================== + + @pytest.mark.parametrize( + "invalid_cp", + [None, "string", 123, []], + ) + def test_validate_model_completion_params_invalid_type(self, invalid_cp): + with pytest.raises(ValueError, match="must be of object type"): + ModelConfigManager.validate_model_completion_params(invalid_cp) + + def test_validate_model_completion_params_default_stop(self): + cp = {"temperature": 0.2} + result = ModelConfigManager.validate_model_completion_params(cp) + assert result["stop"] == [] + + def test_validate_model_completion_params_invalid_stop_type(self): + cp = {"stop": "invalid"} + with pytest.raises(ValueError, match="must be of list type"): + ModelConfigManager.validate_model_completion_params(cp) + + def test_validate_model_completion_params_stop_length_exceeded(self): + cp = {"stop": [1, 2, 3, 4, 5]} + with pytest.raises(ValueError, match="less than 4"): + ModelConfigManager.validate_model_completion_params(cp) + + # ========================================================== + # validate_and_set_defaults + # ========================================================== + + def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) + + assert updated_config["model"]["mode"] == "chat" + assert keys == ["model"] + + def test_validate_and_set_defaults_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", {}) + + def test_validate_and_set_defaults_model_not_dict(self): + with pytest.raises(ValueError, match="object type"): + ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"}) + + def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): + config = {"model": {"name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): + config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.name is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = [] + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {} + + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model] + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + assert updated_config["model"]["mode"] == "completion" + + def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + with pytest.raises(ValueError, match="completion_params is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list): + """ + Covers branch where provider does not contain '/' and + ModelProviderID conversion is triggered (line 64). + """ + config = { + "model": { + "provider": "openai", # no slash -> triggers conversion + "name": "gpt-4", + "completion_params": {}, + } + } + + # Mock ModelProviderID to return formatted provider + mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") + mock_provider_id.return_value = "openai/gpt" + + # Mock provider factory + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + provider_entity = MagicMock() + provider_entity.provider = "openai/gpt" + mock_factory.return_value.get_providers.return_value = [provider_entity] + + # Mock provider manager + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + # Ensure conversion happened + mock_provider_id.assert_called_once_with("openai") + assert updated_config["model"]["provider"] == "openai/gpt" diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py new file mode 100644 index 0000000000..fd49072cd5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( + PromptTemplateConfigManager, +) + +# ----------------------------- +# Helpers +# ----------------------------- + + +class DummyEnumValue: + def __init__(self, value): + self.value = value + + +class DummyPromptType: + def __init__(self): + self.SIMPLE = "simple" + self.ADVANCED = "advanced" + + def value_of(self, value): + return value + + def __iter__(self): + return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + + +# ----------------------------- +# Convert Tests +# ----------------------------- + + +class TestPromptTemplateConfigManagerConvert: + def test_convert_missing_prompt_type_raises(self): + with pytest.raises(ValueError, match="prompt_type is required"): + PromptTemplateConfigManager.convert({}) + + def test_convert_simple_prompt(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mock_prompt_entity_cls.return_value = "simple_entity" + + config = {"prompt_type": "simple", "pre_prompt": "hello"} + + result = PromptTemplateConfigManager.convert(config) + + assert result == "simple_entity" + mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello") + + def test_convert_advanced_chat_valid(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of", + return_value="role_enum", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity", + return_value="chat_msg", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity", + return_value="chat_template", + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]}, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + @pytest.mark.parametrize( + "message", + [ + {"text": 123, "role": "user"}, + {"text": "hi", "role": 123}, + ], + ) + def test_convert_advanced_invalid_message_fields(self, mocker, message): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [message]}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.convert(config) + + def test_convert_advanced_completion_with_roles(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity", + return_value="completion_template", + ) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "complete"}, + "conversation_histories_role": { + "user_prefix": "U", + "assistant_prefix": "A", + }, + }, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + +# ----------------------------- +# validate_and_set_defaults +# ----------------------------- + + +class TestValidateAndSetDefaults: + def setup_method(self): + self.valid_model = {"mode": "chat"} + + def _patch_prompt_type(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + return mock_prompt_entity_cls + + def test_default_prompt_type_set(self, mocker): + self._patch_prompt_type(mocker) + + config = {"model": self.valid_model} + + result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + assert result["prompt_type"] == "simple" + assert isinstance(keys, list) + + def test_invalid_prompt_type_raises(self, mocker): + class InvalidEnum(DummyPromptType): + def __iter__(self): + return iter([DummyEnumValue("valid")]) + + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = InvalidEnum() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = {"prompt_type": "invalid", "model": self.valid_model} + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_invalid_chat_prompt_config_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "chat_prompt_config": "invalid", + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_simple_mode_invalid_pre_prompt_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "pre_prompt": 123, + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_requires_one_config(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {}, + "completion_prompt_config": {}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_invalid_model_mode(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": []}, + "model": {"mode": "invalid"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_chat_prompt_length_exceeds(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{}] * 11}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_completion_prefix_defaults_set_when_empty(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "hi"}, + "conversation_histories_role": { + "user_prefix": "", + "assistant_prefix": "", + }, + }, + "model": {"mode": "completion"}, + } + + updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config) + + roles = updated["completion_prompt_config"]["conversation_histories_role"] + assert roles["user_prefix"] == "Human" + assert roles["assistant_prefix"] == "Assistant" + + +# ----------------------------- +# validate_post_prompt +# ----------------------------- + + +class TestValidatePostPrompt: + @pytest.mark.parametrize("value", [None, ""]) + def test_post_prompt_defaults(self, value): + config = {"post_prompt": value} + result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) + assert result["post_prompt"] == "" + + def test_post_prompt_invalid_type(self): + config = {"post_prompt": 123} + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py new file mode 100644 index 0000000000..5def29b741 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -0,0 +1,286 @@ +import pytest + +from core.app.app_config.easy_ui_based_app.variables.manager import ( + BasicVariablesConfigManager, +) +from dify_graph.variables.input_entities import VariableEntityType + + +class TestBasicVariablesConfigManagerConvert: + def test_convert_empty_config(self): + config = {} + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + def test_convert_external_data_tools_enabled_and_disabled(self, mocker): + config = { + "external_data_tools": [ + {"enabled": False}, + { + "enabled": True, + "variable": "ext_var", + "type": "tool_type", + "config": {"k": "v"}, + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert len(external) == 1 + assert external[0].variable == "ext_var" + assert external[0].type == "tool_type" + + def test_convert_user_input_form_variable_types(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "variable": "name", + "label": "Name", + "description": "desc", + "required": True, + "max_length": 50, + } + }, + { + VariableEntityType.SELECT: { + "variable": "choice", + "label": "Choice", + "options": ["a", "b"], + } + }, + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + "config": {"x": 1}, + } + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert len(variables) == 2 + assert len(external) == 1 + + def test_convert_external_data_tool_without_config_skipped(self): + config = { + "user_input_form": [ + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + } + } + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + +class TestValidateVariablesAndSetDefaults: + def test_validate_sets_empty_user_input_form_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"] == [] + assert "user_input_form" in keys + + def test_validate_user_input_form_not_list_raises(self): + config = {"user_input_form": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_invalid_key_raises(self): + config = {"user_input_form": [{"invalid": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_label_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_label_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_variable_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_variable_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + @pytest.mark.parametrize( + "variable_name", + ["1invalid", "invalid space", "", None], + ) + def test_validate_variable_invalid_pattern_raises(self, variable_name): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": variable_name, + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_required_default_and_type(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + } + } + ] + } + + updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False + + def test_validate_required_not_bool_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + "required": "yes", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_default_not_in_options_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": ["a", "b"], + "default": "c", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_not_list_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": "not_list", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + +class TestValidateExternalDataToolsAndSetDefaults: + def test_validate_sets_empty_external_data_tools_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + assert updated["external_data_tools"] == [] + assert "external_data_tools" in keys + + def test_validate_external_data_tools_not_list_raises(self): + config = {"external_data_tools": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_disabled_tool_skipped(self, mocker): + config = {"external_data_tools": [{"enabled": False}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + spy.assert_not_called() + assert updated["external_data_tools"][0]["enabled"] is False + + def test_validate_enabled_tool_missing_type_raises(self): + config = {"external_data_tools": [{"enabled": True, "config": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_enabled_tool_calls_factory(self, mocker): + config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config) + + spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1}) + + +class TestValidateAndSetDefaultsIntegration: + def test_validate_and_set_defaults_calls_both(self, mocker): + config = {} + + spy_var = mocker.patch.object( + BasicVariablesConfigManager, + "validate_variables_and_set_defaults", + return_value=(config, ["user_input_form"]), + ) + spy_ext = mocker.patch.object( + BasicVariablesConfigManager, + "validate_external_data_tools_and_set_defaults", + return_value=(config, ["external_data_tools"]), + ) + + updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config) + + spy_var.assert_called_once() + spy_ext.assert_called_once() + assert "user_input_form" in keys + assert "external_data_tools" in keys + assert updated == config diff --git a/api/tests/unit_tests/core/app/app_config/features/__init__.py b/api/tests/unit_tests/core/app/app_config/features/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py new file mode 100644 index 0000000000..dd00c3defc --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -0,0 +1,115 @@ +import pytest + +from core.app.app_config.entities import TextToSpeechEntity +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager + + +class TestAdditionalFeatureManagers: + def test_opening_statement_validate_defaults(self): + config, keys = OpeningStatementConfigManager.validate_and_set_defaults({}) + assert config["opening_statement"] == "" + assert config["suggested_questions"] == [] + assert set(keys) == {"opening_statement", "suggested_questions"} + + def test_opening_statement_validate_types(self): + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123}) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": "bad"} + ) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": [1]} + ) + + def test_opening_statement_convert(self): + opening, questions = OpeningStatementConfigManager.convert( + {"opening_statement": "hello", "suggested_questions": ["q1"]} + ) + assert opening == "hello" + assert questions == ["q1"] + + def test_retrieval_resource_validate(self): + config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({}) + assert config["retriever_resource"]["enabled"] is False + assert keys == ["retriever_resource"] + + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"}) + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}}) + + def test_retrieval_resource_convert(self): + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False + + def test_speech_to_text_validate_and_convert(self): + config, keys = SpeechToTextConfigManager.validate_and_set_defaults({}) + assert config["speech_to_text"]["enabled"] is False + assert keys == ["speech_to_text"] + + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"}) + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}}) + + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False + + def test_suggested_questions_after_answer_validate_and_convert(self): + config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({}) + assert config["suggested_questions_after_answer"]["enabled"] is False + assert keys == ["suggested_questions_after_answer"] + + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": "bad"} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": "yes"}} + ) + + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) + is True + ) + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}}) + is False + ) + + def test_text_to_speech_validate_and_convert(self): + config, keys = TextToSpeechConfigManager.validate_and_set_defaults({}) + assert config["text_to_speech"]["enabled"] is False + assert keys == ["text_to_speech"] + + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"}) + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}}) + + result = TextToSpeechConfigManager.convert( + {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}} + ) + assert isinstance(result, TextToSpeechEntity) + assert result.voice == "v" + assert result.language == "en" + + def test_more_like_this_convert_and_validate(self): + config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({}) + assert config["more_like_this"]["enabled"] is False + assert keys == ["more_like_this"] + + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False + with pytest.raises(ValueError): + MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"}) diff --git a/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py new file mode 100644 index 0000000000..e99852cf76 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py @@ -0,0 +1,180 @@ +from collections import UserDict +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager + + +class TestBaseAppConfigManager: + @pytest.fixture + def mock_config_dict(self): + return {"key": "value", "another": 123} + + @pytest.fixture + def mock_app_additional_features(self, mocker): + mock_instance = MagicMock() + mocker.patch( + "core.app.app_config.base_app_config_manager.AppAdditionalFeatures", + return_value=mock_instance, + ) + return mock_instance + + @pytest.fixture + def mock_managers(self, mocker): + retrieval = mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + return_value="retrieval_result", + ) + file_upload = mocker.patch( + "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert", + return_value="file_upload_result", + ) + opening_statement = mocker.patch( + "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert", + return_value=("opening_result", "suggested_result"), + ) + suggested_after = mocker.patch( + "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert", + return_value="suggested_after_result", + ) + more_like_this = mocker.patch( + "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert", + return_value="more_like_this_result", + ) + speech_to_text = mocker.patch( + "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert", + return_value="speech_to_text_result", + ) + text_to_speech = mocker.patch( + "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert", + return_value="text_to_speech_result", + ) + + return { + "retrieval": retrieval, + "file_upload": file_upload, + "opening_statement": opening_statement, + "suggested_after": suggested_after, + "more_like_this": more_like_this, + "speech_to_text": speech_to_text, + "text_to_speech": text_to_speech, + } + + @pytest.mark.parametrize( + ("app_mode", "expected_is_vision"), + [ + ("CHAT", True), + ("COMPLETION", True), + ("AGENT_CHAT", True), + ("OTHER", False), + ], + ) + def test_convert_features_all_modes( + self, + mocker, + mock_config_dict, + mock_app_additional_features, + mock_managers, + app_mode, + expected_is_vision, + ): + # Arrange + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode) + + # Assert + assert result == mock_app_additional_features + mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["file_upload"].assert_called_once() + _, kwargs = mock_managers["file_upload"].call_args + assert kwargs["config"] == dict(mock_config_dict.items()) + assert kwargs["is_vision"] is expected_is_vision + + mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items())) + + def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + empty_config = {} + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(empty_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called + + @pytest.mark.parametrize( + "invalid_config", + [ + None, + "string", + 123, + 12.34, + [], + ], + ) + def test_convert_features_invalid_config_raises(self, invalid_config): + # Act & Assert + with pytest.raises((TypeError, AttributeError)): + BaseAppConfigManager.convert_features(invalid_config, "CHAT") + + def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict): + # Arrange + mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + side_effect=RuntimeError("manager failure"), + ) + + # Act & Assert + with pytest.raises(RuntimeError): + BaseAppConfigManager.convert_features(mock_config_dict, "CHAT") + + def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + class CustomMapping(UserDict): + pass + + custom_config = CustomMapping({"a": 1}) + + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(custom_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py new file mode 100644 index 0000000000..eafdf99c16 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -0,0 +1,43 @@ +import pytest + +from core.app.app_config.entities import ( + DatasetRetrieveConfigEntity, + PromptTemplateEntity, +) +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +class TestAppConfigEntities: + def test_variable_entity_coerces_none_description_and_options(self): + entity = VariableEntity( + variable="query", + label="Query", + description=None, + type=VariableEntityType.TEXT_INPUT, + options=None, + ) + + assert entity.description == "" + assert entity.options == [] + + def test_variable_entity_rejects_invalid_json_schema(self): + with pytest.raises(ValueError): + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + json_schema={"type": "string", "minLength": "bad"}, + ) + + def test_prompt_template_value_of(self): + assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE + with pytest.raises(ValueError): + PromptTemplateEntity.PromptType.value_of("missing") + + def test_dataset_retrieve_strategy_value_of(self): + assert ( + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single") + == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + ) + with pytest.raises(ValueError): + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing") diff --git a/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py new file mode 100644 index 0000000000..fa128aca87 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py @@ -0,0 +1,222 @@ +import pytest + +from core.app.app_config.workflow_ui_based_app.variables.manager import ( + WorkflowVariablesConfigManager, +) + +# ============================= +# Fixtures +# ============================= + + +@pytest.fixture +def mock_workflow(mocker): + workflow = mocker.MagicMock() + workflow.graph_dict = {"nodes": []} + return workflow + + +@pytest.fixture +def mock_variable_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity") + + +@pytest.fixture +def mock_rag_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity") + + +# ============================= +# Test Convert (user_input_form) +# ============================= + + +class TestWorkflowVariablesConfigManagerConvert: + def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity): + # Arrange + input_variables = [{"name": "var1"}, {"name": "var2"}] + mock_workflow.user_input_form.return_value = input_variables + mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [{"validated": v} for v in input_variables] + assert mock_variable_entity.model_validate.call_count == 2 + + def test_convert_empty_list(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [] + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [] + mock_variable_entity.model_validate.assert_not_called() + + def test_convert_none_returned_raises(self, mock_workflow): + # Arrange + mock_workflow.user_input_form.return_value = None + + # Act & Assert + with pytest.raises(TypeError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [{"invalid": "data"}] + mock_variable_entity.model_validate.side_effect = ValueError("validation error") + + # Act & Assert + with pytest.raises(ValueError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + +# ============================= +# Test convert_rag_pipeline_variable +# ============================= + + +class TestWorkflowVariablesConfigManagerConvertRag: + def test_no_rag_pipeline_variables(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = [] + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_rag_pipeline_none(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = None + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}] + + def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + @pytest.mark.parametrize( + ("belong_to_node_id", "expected_count"), + [ + ("node1", 1), + ("shared", 1), + ("other_node", 0), + ], + ) + def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": belong_to_node_id}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == expected_count + + def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + + def test_validation_error_propagates(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed") + + # Act & Assert + with pytest.raises(RuntimeError): + WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py new file mode 100644 index 0000000000..7c21b00966 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -0,0 +1,12 @@ +from core.app.entities.queue_entities import QueueStopEvent + + +class TestQueueEntities: + def test_get_stop_reason_for_known_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + assert event.get_stop_reason() == "Stopped by user." + + def test_get_stop_reason_for_unknown_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + event.stopped_by = "unknown" + assert event.get_stop_reason() == "Stopped by unknown reason." diff --git a/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py new file mode 100644 index 0000000000..1e0ef6d6d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py @@ -0,0 +1,17 @@ +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity + + +class TestRagPipelineInvokeEntity: + def test_defaults_and_fields(self): + entity = RagPipelineInvokeEntity( + pipeline_id="pipe-1", + application_generate_entity={"foo": "bar"}, + user_id="user-1", + tenant_id="tenant-1", + workflow_id="workflow-1", + streaming=True, + ) + + assert entity.workflow_execution_id is None + assert entity.workflow_thread_pool_id is None + assert entity.streaming is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py new file mode 100644 index 0000000000..8ecab3199c --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -0,0 +1,78 @@ +from core.app.entities.task_entities import ( + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + StreamEvent, +) +from dify_graph.enums import WorkflowNodeExecutionStatus + + +class TestTaskEntities: + def test_node_start_to_ignore_detail_dict(self): + data = NodeStartStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + created_at=1, + ) + response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_STARTED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["extras"] == {} + + def test_node_finish_to_ignore_detail_dict(self): + data = NodeFinishStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ) + response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_FINISHED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["outputs"] is None + assert payload["data"]["files"] == [] + + def test_node_retry_to_ignore_detail_dict(self): + data = NodeRetryStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.RETRY, + elapsed_time=0.1, + created_at=1, + finished_at=2, + retry_index=2, + ) + response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_RETRY.value + assert payload["data"]["retry_index"] == 2 + assert payload["data"]["outputs"] is None diff --git a/api/tests/unit_tests/core/app/features/test_annotation_reply.py b/api/tests/unit_tests/core/app/features/test_annotation_reply.py new file mode 100644 index 0000000000..e721a77079 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_annotation_reply.py @@ -0,0 +1,163 @@ +import logging +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature + + +class TestAnnotationReplyFeature: + def test_query_returns_none_when_setting_missing(self): + feature = AnnotationReplyFeature() + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = None + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_none_when_binding_missing(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace(collection_binding_detail=None) + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = annotation_setting + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_annotation_and_records_history_for_api(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result == annotation + mock_annotation_service.add_annotation_history.assert_called_once() + _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "api" + assert score == 0.8 + + def test_query_returns_annotation_and_records_history_for_console(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=0.5, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=None, + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.EXPLORE, + ) + + assert result == annotation + _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "console" + + def test_query_logs_and_returns_none_on_exception(self, caplog): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1") + mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom") + + with caplog.at_level(logging.WARNING): + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + assert "Query annotation failed" in caplog.text diff --git a/api/tests/unit_tests/core/app/features/test_hosting_moderation.py b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py new file mode 100644 index 0000000000..01194c16f5 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py @@ -0,0 +1,30 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature + + +class TestHostingModerationFeature: + def test_check_aggregates_text_and_calls_moderation(self): + application_generate_entity = Mock() + application_generate_entity.model_conf = {"model": "mock"} + application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1") + + prompt_messages = [ + SimpleNamespace(content="hello"), + SimpleNamespace(content=123), + SimpleNamespace(content="world"), + ] + + with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check: + mock_check.return_value = True + + feature = HostingModerationFeature() + result = feature.check(application_generate_entity, prompt_messages) + + assert result is True + mock_check.assert_called_once_with( + tenant_id="tenant-1", + model_config=application_generate_entity.model_conf, + text="hello\nworld\n", + ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py new file mode 100644 index 0000000000..c6d820dbc9 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -0,0 +1,19 @@ +from core.app.layers.suspend_layer import SuspendLayer +from dify_graph.graph_events.graph import GraphRunPausedEvent + + +class TestSuspendLayer: + def test_on_event_accepts_paused_event(self): + layer = SuspendLayer() + assert layer.is_paused() is False + layer.on_graph_start() + assert layer.is_paused() is False + layer.on_event(GraphRunPausedEvent()) + assert layer.is_paused() is True + + def test_on_event_ignores_other_events(self): + layer = SuspendLayer() + layer.on_graph_start() + initial_state = layer.is_paused() + layer.on_event(object()) + assert layer.is_paused() is initial_state diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py new file mode 100644 index 0000000000..c87eec1508 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -0,0 +1,98 @@ +from unittest.mock import Mock, patch + +from core.app.layers.timeslice_layer import TimeSliceLayer +from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import SchedulerCommand + + +class TestTimeSliceLayer: + def test_init_starts_scheduler_when_not_running(self): + scheduler = Mock() + scheduler.running = False + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock())) + + scheduler.start.assert_called_once() + + def test_on_graph_start_adds_job_for_time_slice(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid, + ): + mock_uuid.return_value.hex = "job-1" + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.on_graph_start() + + assert layer.schedule_id == "job-1" + scheduler.add_job.assert_called_once() + + def test_on_graph_end_removes_job(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.schedule_id = "job-1" + layer.on_graph_end(None) + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_removes_when_stopped(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.stopped = True + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_handles_resource_limit_without_command_channel(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.logger") as mock_logger, + ): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + mock_logger.exception.assert_called_once() + + def test_checker_job_sends_pause_command(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.command_channel = Mock() + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + layer.command_channel.send_command.assert_called_once() + sent_command = layer.command_channel.send_command.call_args[0][0] + assert isinstance(sent_command, GraphEngineCommand) + assert sent_command.command_type == CommandType.PAUSE diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py new file mode 100644 index 0000000000..f9755061d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -0,0 +1,106 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.layers.trigger_post_layer import TriggerPostLayer +from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent +from models.enums import WorkflowTriggerStatus + + +class TestTriggerPostLayer: + def test_on_event_updates_trigger_log(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"answer": "ok"}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=12, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunSucceededEvent()) + + assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 12 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + + def test_on_event_handles_missing_trigger_log(self): + runtime_state = SimpleNamespace( + outputs={}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=0, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.logger") as mock_logger, + ): + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = None + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="missing", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunFailedEvent(error="boom")) + + mock_logger.exception.assert_called_once() + session.commit.assert_not_called() + + def test_on_event_ignores_non_status_events(self): + runtime_state = SimpleNamespace( + outputs={}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=0, + ) + + with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory: + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(Mock()) + + mock_session_factory.create_session.assert_not_called() diff --git a/api/tests/unit_tests/core/app/task_pipeline/__init__.py b/api/tests/unit_tests/core/app/task_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py new file mode 100644 index 0000000000..e070eb06fd --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.entities.queue_entities import QueueErrorEvent +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.errors.error import QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from models.enums import MessageStatus + + +class TestBasedGenerateTaskPipeline: + @pytest.fixture + def pipeline(self): + app_config = SimpleNamespace( + tenant_id="tenant-1", + app_id="app-1", + sensitive_word_avoidance=None, + ) + app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config) + return BasedGenerateTaskPipeline( + application_generate_entity=app_generate_entity, + queue_manager=Mock(), + stream=True, + ) + + def test_error_to_desc_quota_exceeded(self, pipeline): + message = pipeline._error_to_desc(QuotaExceededError()) + assert "quota" in message.lower() + + def test_handle_error_wraps_invoke_authorization(self, pipeline): + event = QueueErrorEvent(error=InvokeAuthorizationError()) + err = pipeline.handle_error(event=event) + assert isinstance(err, InvokeAuthorizationError) + assert str(err) == "Incorrect API key provided" + + def test_handle_error_preserves_invoke_error(self, pipeline): + event = QueueErrorEvent(error=InvokeError("bad")) + err = pipeline.handle_error(event=event) + assert err is event.error + + def test_handle_error_updates_message_when_found(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + message = SimpleNamespace(status=MessageStatus.NORMAL, error=None) + session = Mock() + session.scalar.return_value = message + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + assert message.status == MessageStatus.ERROR + assert message.error == "oops" + + def test_handle_error_returns_err_when_message_missing(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + session = Mock() + session.scalar.return_value = None + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + + def test_error_to_stream_response_and_ping(self, pipeline): + error_response = pipeline.error_to_stream_response(ValueError("boom")) + ping_response = pipeline.ping_stream_response() + + assert error_response.task_id == "task-1" + assert ping_response.task_id == "task-1" + + def test_handle_output_moderation_when_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("filtered", True) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result == "filtered" + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None + + def test_handle_output_moderation_when_not_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("safe", False) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result is None + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py new file mode 100644 index 0000000000..155e6f2c73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -0,0 +1,1228 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppStreamResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AudioTrunk +from dify_graph.file.enums import FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent +from models.model import AppMode + + +class _DummyModelConf: + def __init__(self) -> None: + self.model = "mock" + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock", model="mock"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_entity(entity_cls, app_mode: AppMode): + app_config = _make_app_config(app_mode) + return entity_cls.model_construct( + task_id="task", + app_config=app_config, + model_conf=_DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +class TestEasyUiBasedGenerateTaskPipeline: + def test_to_blocking_response_chat(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_to_blocking_response_completion(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_listen_audio_msg_returns_none_when_no_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_process_stream_response_handles_chunks_and_end(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[TextPromptMessageContent(data="hi"), TextPromptMessageContent(data="yo")] + ), + ), + ) + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + events = [ + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline.handle_output_moderation_when_task_finished = lambda completion: None + pipeline._message_end_to_stream_response = lambda: "end" + pipeline._save_message = lambda **kwargs: None + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "chunk" in responses + assert "replace" in responses + assert any(isinstance(item, PingStreamResponse) for item in responses) + assert responses[-1] == "end" + + def test_handle_output_moderation_chunk_directs_output(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline.output_moderation_handler = _Moderation() + pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is True + assert any(isinstance(event, QueueLLMChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_stop_updates_usage(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + class _ModelType: + def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): + return LLMUsage.from_metadata( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + ) + + class _ModelConf: + def __init__(self) -> None: + self.model = "mock" + self.credentials = {} + self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + + app_config = _make_app_config(AppMode.CHAT) + application_generate_entity = ChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + model_conf=_ModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content="answer") + + calls: list[int] = [] + + class _FakeModelInstance: + def __init__(self, provider_model_bundle, model): + pass + + def get_llm_num_tokens(self, messages): + calls.append(1) + return 10 if len(calls) == 1 else 5 + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.ModelInstance", + _FakeModelInstance, + ) + + pipeline._handle_stop(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) + + assert pipeline._task_state.llm_result.usage.prompt_tokens == 10 + assert pipeline._task_state.llm_result.usage.completion_tokens == 5 + + def test_record_files_builds_file_payloads(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + message_files = [ + SimpleNamespace( + id="mf-1", + message_id="msg", + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/a.png", + upload_file_id=None, + type="image", + ), + SimpleNamespace( + id="mf-2", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-1", + type="image", + ), + SimpleNamespace( + id="mf-3", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/file.bin", + upload_file_id=None, + type="file", + ), + ] + upload_files = [ + SimpleNamespace( + id="upload-1", + name="local.png", + mime_type="image/png", + size=123, + extension="png", + ) + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else upload_files) + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "signed-url", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "signed-tool", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files + assert len(files) == 3 + + def test_process_stream_response_handles_annotation_and_error(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="agent"), + ), + ) + + events = [ + SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), + SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), + SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), + SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), + SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") + pipeline._agent_thought_to_stream_response = lambda event: "thought" + pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" + pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" + pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline.error_to_stream_response = lambda err: err + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "thought" in responses + assert "file" in responses + assert "agent" in responses + assert isinstance(responses[-1], ValueError) + assert pipeline._task_state.llm_result.message.content == "annotatedagent" + + def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_thought = SimpleNamespace( + id="thought", + position=1, + thought="t", + observation="o", + tool="tool", + tool_labels={}, + tool_input="input", + files=[], + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return agent_thought + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) + + assert response is not None + assert response.id == "thought" + + def test_process_routes_to_stream_and_starts_conversation_name_generation(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_stream_response = lambda generator: "streamed" + + result = pipeline.process() + + assert result == "streamed" + pipeline._message_cycle_manager.generate_conversation_name.assert_called_once_with( + conversation_id="conv", query="hello" + ) + + def test_process_routes_to_blocking_for_completion_mode(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock() + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_blocking_response = lambda generator: "blocking" + + result = pipeline.process() + + assert result == "blocking" + pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() + + def test_to_blocking_response_raises_error_stream_exception(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("stream error")) + + with pytest.raises(ValueError, match="stream error"): + pipeline._to_blocking_response(_gen()) + + def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(RuntimeError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_gen()) + + def test_to_stream_response_wraps_completion_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, CompletionAppStreamResponse) + assert response.message_id == "msg" + + def test_to_stream_response_wraps_chat_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, ChatbotAppStreamResponse) + assert response.conversation_id == "conv" + + def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + assert response.audio == "abc" + + def test_listen_audio_msg_returns_none_for_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + + assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses == ["payload"] + + def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def check_and_get_audio(self): + return AudioTrunk("finish", "") + + inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") + audio_calls = iter([inline_audio, None]) + pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses[0] == inline_audio + assert responses[1] == "payload" + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def __init__(self): + self._events = iter([None, AudioTrunk("responding", "later"), AudioTrunk("finish", "")]) + + def check_and_get_audio(self): + return next(self._events) + + clock = {"value": 0.0} + + def _fake_time(): + clock["value"] += 0.1 + return clock["value"] + + pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.time", _fake_time) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.sleep", lambda _: None) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._task_state.llm_result.message.content = "raw answer" + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_stop = Mock() + pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" + pipeline._save_message = lambda **kwargs: None + pipeline._message_end_to_stream_response = lambda: "end" + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["replace:moderated answer", "end"] + pipeline._handle_stop.assert_called_once() + + def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + handled = {"retriever": 0} + + def _handle_retriever_resources(event): + handled["retriever"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources + pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=retriever_event), + SimpleNamespace(event=SimpleNamespace()), + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + ] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + assert handled["retriever"] == 1 + + def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), + ) + pipeline._handle_output_moderation_chunk = lambda text: True + pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + + def test_process_stream_response_ignores_unsupported_chunk_content_types(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = SimpleNamespace( + prompt_messages=[], + delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), + ) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["ok"] + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = object() + pipeline.queue_manager.listen = lambda: iter([]) + + assert list(pipeline._process_stream_response(publisher=None)) == [] + + def test_save_message_persists_fields_and_emits_trace(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline.start_at = 10.0 + pipeline._model_config = SimpleNamespace(mode="chat") + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( + {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} + ) + + message_obj = SimpleNamespace(id="msg") + conversation_obj = SimpleNamespace(id="conv") + session = Mock() + session.scalar.side_effect = [message_obj, conversation_obj] + trace_manager = SimpleNamespace(add_trace_task=Mock()) + sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptMessageUtil.prompt_messages_to_prompt_for_saving", + lambda mode, prompt_messages: "serialized-prompt", + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptTemplateParser.remove_template_variables", + lambda text: text.replace("{{name}}", "").strip(), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.naive_utc_now", + lambda: datetime(2024, 1, 1, tzinfo=UTC), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.perf_counter", lambda: 15.0 + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.message_was_created.send", + lambda *args, **kwargs: sent_payloads.append((args, kwargs)), + ) + + pipeline._save_message(session=session, trace_manager=trace_manager) + + assert message_obj.message == "serialized-prompt" + assert message_obj.answer == "hello" + assert message_obj.provider_response_latency == 5.0 + assert trace_manager.add_trace_task.called + assert len(sent_payloads) == 1 + + def test_save_message_raises_when_message_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.return_value = None + + with pytest.raises(ValueError, match="message msg not found"): + pipeline._save_message(session=session) + + def test_save_message_raises_when_conversation_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + + with pytest.raises(ValueError, match="Conversation conv not found"): + pipeline._save_message(session=session) + + def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 2}) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.id == "msg" + assert response.metadata["usage"]["prompt_tokens"] == 1 + + def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.files is None + + def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + message_files = [ + SimpleNamespace( + id="mf-local-fallback", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-missing", + type="file", + ), + SimpleNamespace( + id="mf-tool-http", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="http://cdn.example.com/file.txt?x=1", + upload_file_id=None, + type="file", + ), + SimpleNamespace( + id="mf-tool-noext", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/path/toolid", + upload_file_id=None, + type="file", + ), + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else []) + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "local-fallback-signed", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "tool-signed", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files is not None + assert files[0]["url"] == "local-fallback-signed" + assert files[1]["filename"] == "file.txt" + assert files[2]["extension"] == ".bin" + + def test_agent_message_to_stream_response_builds_payload(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + response = pipeline._agent_message_to_stream_response(answer="hello", message_id="msg") + + assert response.id == "msg" + assert response.answer == "hello" + + def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) + + assert response is None + + def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + appended_tokens: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + appended_tokens.append(text) + + pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("next-token") + + assert result is False + assert appended_tokens == ["next-token"] diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_exc.py b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py new file mode 100644 index 0000000000..9ea7e96e73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py @@ -0,0 +1,11 @@ +from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError + + +class TestTaskPipelineExceptions: + def test_record_not_found_error_message(self): + err = RecordNotFoundError("Message", "msg-1") + assert str(err) == "Message with id msg-1 not found" + + def test_workflow_run_not_found_error_message(self): + err = WorkflowRunNotFoundError("run-1") + assert str(err) == "WorkflowRun with id run-1 not found" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index c0c636715d..07ee75ed35 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,12 +1,16 @@ """Unit tests for the message cycle manager optimization.""" +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from flask import current_app +from flask import Flask, current_app -from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from models.model import AppMode class TestMessageCycleManagerOptimization: @@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE mock_session.scalar.assert_called_once() + def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager): + """Return MESSAGE_FILE directly from in-memory cache without opening a DB session.""" + message_cycle_manager._message_has_file.add("cached-message") + + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + result = message_cycle_manager.get_message_event_type("cached-message") + + assert result == StreamEvent.MESSAGE_FILE + mock_session_factory.create_session.assert_not_called() + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization: assert chunk2_response.event == StreamEvent.MESSAGE assert chunk1_response.answer == "Chunk 1" assert chunk2_response.answer == "Chunk 2" + + def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager): + """Return None when completion entities are used for conversation naming. + + Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity. + Returns: None, indicating no name generation for completion apps. + Side effects: None expected. + """ + + class DummyCompletion: + pass + + with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion): + message_cycle_manager._application_generate_entity = DummyCompletion() + result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi") + + assert result is None + + def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager): + """Spawn background generation thread for the first chat message.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True} + flask_app = object() + + class DummyTimer: + def __init__(self, interval, function, args=None, kwargs=None): + self.interval = interval + self.function = function + self.args = args or [] + self.kwargs = kwargs + self.daemon = False + self.started = False + + def start(self): + self.started = True + + with ( + patch( + "core.app.task_pipeline.message_cycle_manager.current_app", + new=SimpleNamespace(_get_current_object=lambda: flask_app), + ), + patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer), + ): + thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello") + + assert isinstance(thread, DummyTimer) + assert thread.interval == 1 + assert thread.function == message_cycle_manager._generate_conversation_name_worker + assert thread.started is True + assert thread.daemon is True + assert thread.kwargs["flask_app"] is flask_app + assert thread.kwargs["conversation_id"] == "conv-1" + assert thread.kwargs["query"] == "hello" + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + + def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager): + """Skip thread creation when auto naming is disabled but still mark conversation as not new.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False} + + with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer: + result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello") + + assert result is None + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + mock_timer.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager): + """Return early when the conversation cannot be found.""" + flask_app = Flask(__name__) + db_session = Mock() + db_session.scalar.return_value = None + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager): + """Return early when non-completion conversation has no app relation.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id") + db_session = Mock() + db_session.scalar.return_value = conversation + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager): + """Use cached conversation name when present and avoid LLM call.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = b"cached-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "cached-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_llm_generator.generate_conversation_name.assert_not_called() + mock_redis.setex.assert_not_called() + + def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager): + """Generate conversation name and write it to redis cache on cache miss.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.return_value = "generated-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "generated-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + """Fallback to truncated query when LLM generation fails.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + long_query = "q" * 60 + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, + patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") + mock_dify_config.DEBUG = True + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + + assert conversation.name == (long_query[:47] + "...") + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_logger.exception.assert_called_once() + + def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): + """Populate task metadata from annotation reply events. + + Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService. + Returns: The fetched annotation object. + Side effects: Updates metadata.annotation_reply with id and account name. + """ + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + annotation = SimpleNamespace( + id="ann-1", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = annotation + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="ann-1") + ) + + assert result == annotation + assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1" + assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice" + + def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager): + """Return None and keep metadata unchanged when annotation is not found.""" + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = None + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="missing") + ) + + assert result is None + assert message_cycle_manager._task_state.metadata.annotation_reply is None + + def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager): + """Merge retriever resources, deduplicate, and preserve ordering positions. + + Args: message_cycle_manager with show_retrieve_source enabled and existing metadata. + Returns: None. + Side effects: Updates metadata.retriever_resources with unique items and positions. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing])) + + duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2") + + event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource]) + message_cycle_manager.handle_retriever_resources(event) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2 + + def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager): + """Build a stream response with a signed tool file URL. + + Args: message_cycle_manager with mocked Session/db and sign_tool_file. + Returns: MessageStreamResponse with signed url and belongs_to normalized to user. + Side effects: Calls sign_tool_file for tool file ids. + """ + message_cycle_manager._application_generate_entity.task_id = "task-1" + + message_file = SimpleNamespace( + id="file-1", + type="image", + belongs_to=None, + url="tool://file.verylongextension", + message_id="msg-1", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-url" + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1")) + + assert response.url == "signed-url" + assert response.belongs_to == "user" + mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin") + + def test_handle_retriever_resources_requires_features(self, message_cycle_manager): + """Raise when retriever resources are handled without feature config. + + Args: message_cycle_manager with additional_features unset and empty metadata. + Raises: ValueError when show_retrieve_source configuration is missing. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with pytest.raises(ValueError): + message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[])) + + def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager): + """Ignore null resource entries while preserving valid resources.""" + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[])) + resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + + message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource])) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + + def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager): + """Use original URL when message file URL is already HTTP.""" + message_cycle_manager._application_generate_entity.task_id = "task-http" + message_file = SimpleNamespace( + id="file-http", + type="image", + belongs_to="assistant", + url="http://example.com/pic.png", + message_id="msg-http", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-http") + ) + + assert response is not None + assert response.url == "http://example.com/pic.png" + assert "msg-http" in message_cycle_manager._message_has_file + + def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager): + """Default tool file extension to .bin when URL has no extension part.""" + message_cycle_manager._application_generate_entity.task_id = "task-bin" + message_file = SimpleNamespace( + id="file-bin", + type="file", + belongs_to="assistant", + url="tool-file-id", + message_id="msg-bin", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-bin-url" + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-bin") + ) + + assert response is not None + assert response.url == "signed-bin-url" + mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin") + + def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager): + """Return None when message file lookup does not find a record.""" + session = Mock() + session.scalar.return_value = None + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing")) + + assert response is None + + def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager): + """Include the provided replacement reason in the stream payload.""" + response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation") + + assert response.answer == "replaced" + assert response.reason == "moderation" diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py new file mode 100644 index 0000000000..fb76f22a2a --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -0,0 +1,43 @@ +from unittest.mock import patch + +from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime + + +class TestDifyWorkflowFileRuntime: + def test_runtime_properties_and_helpers(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + + runtime = DifyWorkflowFileRuntime() + + assert runtime.files_url == "http://files" + assert runtime.internal_files_url == "http://internal" + assert runtime.secret_key == "secret" + assert runtime.files_access_timeout == 123 + assert runtime.multimodal_send_format == "url" + + with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get: + mock_get.return_value = "response" + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch("core.app.workflow.file_runtime.storage.load") as mock_load: + mock_load.return_value = b"data" + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign: + mock_sign.return_value = "signed" + assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed" + mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False) + + def test_bind_runtime_registers_instance(self): + with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set: + bind_dify_workflow_file_runtime() + + mock_set.assert_called_once() + runtime = mock_set.call_args[0][0] + assert isinstance(runtime, DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py new file mode 100644 index 0000000000..9e742507c6 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -0,0 +1,161 @@ +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import BuiltinNodeTypes + + +class DummyNode: + def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = id + self.config = config + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self.kwargs = kwargs + + +class DummyCodeNode(DummyNode): + @classmethod + def default_code_providers(cls): + return () + + +class DummyTemplateTransformNode(DummyNode): + pass + + +class DummyHttpRequestNode(DummyNode): + pass + + +class DummyKnowledgeRetrievalNode(DummyNode): + pass + + +class DummyDocumentExtractorNode(DummyNode): + pass + + +class TestDifyNodeFactory: + @staticmethod + def _stub_node_resolution(monkeypatch, node_class): + monkeypatch.setattr( + "core.workflow.node_factory.resolve_workflow_node_class", + lambda **_kwargs: node_class, + ) + + def _factory(self, monkeypatch): + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100) + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u") + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key") + + run_context = build_dify_run_context( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + return DifyNodeFactory( + graph_init_params=SimpleNamespace(run_context=run_context), + graph_runtime_state=SimpleNamespace(), + ) + + def test_create_node_unknown_type(self, monkeypatch): + factory = self._factory(monkeypatch) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": "unknown"}}) + + def test_create_node_missing_mapping(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {}) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_missing_latest_class(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr( + "core.workflow.node_factory.get_node_type_classes_mapping", + lambda: {BuiltinNodeTypes.START: {"1": None}}, + ) + monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest") + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_selects_versioned_class(self, monkeypatch): + factory = self._factory(monkeypatch) + selected_versions: list[tuple[str, str]] = [] + + class DummyNodeV2(DummyNode): + pass + + def _get_mapping(): + selected_versions.append(("snapshot", "called")) + return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}} + + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}}) + + assert isinstance(node, DummyNodeV2) + assert node.id == "node-1" + assert selected_versions == [("snapshot", "called")] + + def test_create_node_code_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyCodeNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}}) + + assert isinstance(node, DummyCodeNode) + assert node.id == "node-1" + + def test_create_node_template_transform_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) + + assert isinstance(node, DummyTemplateTransformNode) + assert "template_renderer" in node.kwargs + + def test_create_node_http_request_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyHttpRequestNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}) + + assert isinstance(node, DummyHttpRequestNode) + assert "http_request_config" in node.kwargs + + def test_create_node_knowledge_retrieval_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}) + + assert isinstance(node, DummyKnowledgeRetrievalNode) + assert node.kwargs == {} + + def test_create_node_document_extractor_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}}) + + assert isinstance(node, DummyDocumentExtractorNode) + assert "unstructured_api_config" in node.kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py new file mode 100644 index 0000000000..0565f4cfe9 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from core.app.workflow.layers.observability import ObservabilityLayer +from dify_graph.enums import BuiltinNodeTypes + + +class TestObservabilityLayerExtras: + def test_init_tracer_enabled_sets_tracer(self, monkeypatch): + tracer = object() + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer) + + layer = ObservabilityLayer() + + assert layer._is_disabled is False + assert layer._tracer is tracer + + def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + def _raise(*_args, **_kwargs): + raise RuntimeError("tracer init failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + assert layer._tracer is None + assert "Failed to get OpenTelemetry tracer" in caplog.text + + def test_init_tracer_disables_when_otel_disabled(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + + def test_get_parser_uses_registry_when_node_type_matches(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL)) + + assert parser is layer._parsers[BuiltinNodeTypes.TOOL] + + def test_get_parser_defaults_when_node_type_missing(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=None)) + + assert parser is layer._default_parser + + def test_on_graph_start_clears_contexts(self): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_start() + + assert layer._node_contexts == {} + + def test_on_event_is_noop(self): + layer = ObservabilityLayer() + + layer.on_event(object()) + + def test_on_graph_end_clears_unfinished_contexts(self, caplog): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_end(error=None) + + assert layer._node_contexts == {} + assert "node spans were not properly ended" in caplog.text + + def test_on_node_run_start_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + layer._tracer = None + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object()) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self): + layer = ObservabilityLayer() + layer._is_disabled = False + calls: list[str] = [] + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called")) + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert calls == [] + + def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + def _raise(*_args, **_kwargs): + raise RuntimeError("start failed") + + layer._tracer = SimpleNamespace(start_span=_raise) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert "Failed to create OpenTelemetry span for node" in caplog.text + + def test_on_node_run_end_without_context_noop(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None) + + assert "exec" in layer._node_contexts + + def test_on_node_run_end_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_calls_span_end(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + ended: list[str] = [] + + class _Parser: + def parse(self, **_kwargs): + return None + + span = SimpleNamespace(end=lambda: ended.append("ended")) + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert ended == ["ended"] + assert "exec" not in layer._node_contexts + + def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + class _Parser: + def parse(self, **_kwargs): + return None + + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token") + + def _raise(*_args, **_kwargs): + raise RuntimeError("detach failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert "Failed to detach OpenTelemetry token" in caplog.text + assert "exec" not in layer._node_contexts + + def test_on_node_run_start_and_end_creates_span(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + + span = SimpleNamespace(end=lambda: None) + tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span) + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token") + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None) + + layer._tracer = tracer + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + + layer.on_node_run_start(node) + assert "exec" in layer._node_contexts + + layer.on_node_run_end(node, error=None) + assert "exec" not in layer._node_contexts diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py new file mode 100644 index 0000000000..45f6a0c7a1 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + SystemVariableKey, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from dify_graph.graph_events.graph import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from dify_graph.graph_events.node import ( + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from dify_graph.node_events import NodeRunResult +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.system_variable import SystemVariable + + +class _RepoRecorder: + def __init__(self) -> None: + self.saved: list[object] = [] + self.saved_exec_data: list[object] = [] + + def save(self, entity): + self.saved.append(entity) + + def save_execution_data(self, entity): + self.saved_exec_data.append(entity) + + +def _naive_utc_now() -> datetime: + return datetime.now(UTC).replace(tzinfo=None) + + +def _make_layer( + system_variable: SystemVariable | None = None, + *, + extras: dict | None = None, + trace_manager: object | None = None, +): + system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id") + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0) + read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) + + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=SimpleNamespace(app_id="app", tenant_id="tenant"), + inputs={"foo": "bar"}, + files=[], + user_id="user", + stream=False, + invoke_from=None, + trace_manager=None, + workflow_execution_id="run-id", + extras=extras or {}, + call_depth=0, + ) + + workflow_info = PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={"nodes": [], "edges": []}, + ) + + workflow_execution_repo = _RepoRecorder() + workflow_node_execution_repo = _RepoRecorder() + + layer = WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=workflow_info, + workflow_execution_repository=workflow_execution_repo, + workflow_node_execution_repository=workflow_node_execution_repo, + trace_manager=trace_manager, + ) + layer.initialize(read_only_state, command_channel=None) + + return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state + + +class TestWorkflowPersistenceLayer: + def test_on_graph_start_resets_state(self): + layer, _, _, _ = _make_layer() + layer._workflow_execution = object() + layer._node_execution_cache["cached"] = object() + layer._node_snapshots["cached"] = object() + layer._node_sequence = 9 + + layer.on_graph_start() + + assert layer._workflow_execution is None + assert layer._node_execution_cache == {} + assert layer._node_snapshots == {} + assert layer._node_sequence == 0 + + def test_get_execution_id_requires_system_variable(self): + system_variable = SystemVariable(workflow_execution_id=None) + layer, _, _, _ = _make_layer(system_variable) + + with pytest.raises(ValueError, match="workflow_execution_id must be provided"): + layer._get_execution_id() + + def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch): + layer, _, _, _ = _make_layer() + + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda inputs: inputs, + ) + + inputs = layer._prepare_workflow_inputs() + + assert "sys.conversation_id" not in inputs + assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id" + + def test_fail_running_node_executions_marks_failed(self): + layer, _, node_repo, _ = _make_layer() + + execution = WorkflowNodeExecution( + id="exec-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[execution.id] = execution + + layer._fail_running_node_executions(error_message="boom") + + assert execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_repo.saved + + def test_handle_graph_run_started_saves_execution(self): + layer, exec_repo, _, _ = _make_layer() + + layer._handle_graph_run_started() + + assert exec_repo.saved + + def test_handle_graph_run_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 3 + runtime_state.node_run_steps = 2 + runtime_state.outputs = {"out": "v"} + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.SUCCEEDED + assert saved.total_tokens == 3 + assert saved.total_steps == 2 + + def test_handle_graph_run_partial_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 5 + runtime_state.node_run_steps = 4 + runtime_state._graph_execution = SimpleNamespace(exceptions_count=2) + + layer._handle_graph_run_partial_succeeded( + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2) + ) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert saved.exceptions_count == 2 + assert saved.total_tokens == 5 + + def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self): + trace_tasks: list[object] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager) + layer._handle_graph_run_started() + + running = WorkflowNodeExecution( + id="node-exec", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[running.id] = running + + layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1)) + + assert node_repo.saved + assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED + assert trace_tasks + + def test_handle_graph_run_aborted_sets_status(self): + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + + layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.STOPPED + assert saved.error_message + + def test_handle_graph_run_paused_updates_outputs(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 7 + runtime_state.node_run_steps = 5 + + layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PAUSED + assert saved.outputs == {"pause": True} + assert saved.finished_at is None + + def test_handle_node_started_and_retry(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + predecessor_node_id="prev", + in_iteration_id="iter", + in_loop_id="loop", + ) + layer._handle_node_started(start_event) + + assert node_repo.saved + assert "exec" in layer._node_execution_cache + assert layer._node_snapshots["exec"].node_id == "node" + + retry_event = NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + layer._handle_node_retry(retry_event) + assert node_repo.saved_exec_data + + def test_handle_node_result_events_update_execution(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={}) + success_event = NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + node_run_result=result, + ) + layer._handle_node_succeeded(success_event) + + failed_event = NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="boom", + node_run_result=result, + ) + layer._handle_node_failed(failed_event) + + exception_event = NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="err", + node_run_result=result, + ) + layer._handle_node_exception(exception_event) + + assert node_repo.saved_exec_data + + def test_handle_node_pause_requested_skips_outputs(self): + layer, _, _, _ = _make_layer() + layer._handle_graph_run_started() + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + domain_execution = layer._node_execution_cache["exec"] + domain_execution.inputs = {"old": True} + + result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={}) + pause_event = NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + reason=SchedulingPause(message="pause"), + node_run_result=result, + ) + layer._handle_node_pause_requested(pause_event) + + assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED + assert domain_execution.inputs == {"old": True} + + def test_get_node_execution_raises_for_missing(self): + layer, _, _, _ = _make_layer() + with pytest.raises(ValueError, match="Node execution not found"): + layer._get_node_execution("missing") + + def test_get_workflow_execution_raises_when_uninitialized(self): + layer, _, _, _ = _make_layer() + + with pytest.raises(ValueError, match="workflow execution not initialized"): + layer._get_workflow_execution() + + def test_next_node_sequence_increments(self): + layer, _, _, _ = _make_layer() + assert layer._next_node_sequence() == 1 + assert layer._next_node_sequence() == 2 + + def test_on_graph_end_is_noop(self): + layer, _, _, _ = _make_layer() + + assert layer.on_graph_end(error=None) is None + + def test_on_event_dispatches_to_all_known_handlers(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_graph_run_started = _record("started") + layer._handle_graph_run_succeeded = _record("succeeded") + layer._handle_graph_run_partial_succeeded = _record("partial") + layer._handle_graph_run_failed = _record("failed") + layer._handle_graph_run_aborted = _record("aborted") + layer._handle_graph_run_paused = _record("paused") + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + layer._handle_node_succeeded = _record("node_succeeded") + layer._handle_node_failed = _record("node_failed") + layer._handle_node_exception = _record("node_exception") + layer._handle_node_pause_requested = _record("node_paused") + + node_result = NodeRunResult() + now = _naive_utc_now() + events = [ + GraphRunStartedEvent(), + GraphRunSucceededEvent(outputs={"ok": True}), + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1), + GraphRunFailedEvent(error="boom", exceptions_count=1), + GraphRunAbortedEvent(reason="stop", outputs={"x": 1}), + GraphRunPausedEvent(outputs={"pause": True}), + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + ), + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + error="retry", + retry_index=1, + ), + NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + node_run_result=node_result, + ), + NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="failed", + node_run_result=node_result, + ), + NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="error", + node_run_result=node_result, + ), + NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + reason=SchedulingPause(message="pause"), + node_run_result=node_result, + ), + ] + expected_order = [ + "started", + "succeeded", + "partial", + "failed", + "aborted", + "paused", + "node_started", + "node_retry", + "node_succeeded", + "node_failed", + "node_exception", + "node_paused", + ] + + for event in events: + layer.on_event(event) + + assert called == expected_order + + def test_on_event_dispatches_retry_before_started_for_retry_event(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + + layer.on_event( + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + ) + + assert called == ["node_retry"] + + def test_enqueue_trace_task_skips_when_disabled(self): + trace_tasks: list[object] = [] + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + assert exec_repo.saved + assert not trace_tasks diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 5ceaa08893..ae5638784c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -110,7 +110,7 @@ def test_encrypt_tool_parameters(): assert encrypted["plain"] == "x" -def test_decrypt_tool_parameters_cache_hit_and_miss(): +def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch): manager = _build_manager() with ( @@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache(): mock_delete.assert_called_once() -def test_configuration_manager_decrypt_suppresses_errors(): +def test_configuration_manager_decrypt_suppresses_errors(monkeypatch): manager = _build_manager() with ( patch.object(ToolParameterCache, "get", return_value=None),