diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 02ec96f209..5c9bc43992 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index e35e9d9408..0c146c388f 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 3aa1161fd8..f23ee7f89f 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: diff --git a/api/tests/unit_tests/core/app/apps/__init__.py b/api/tests/unit_tests/core/app/apps/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py b/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py new file mode 100644 index 0000000000..6ca4f60459 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from models.model import AppMode + + +class TestAdvancedChatAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.ADVANCED_CHAT + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + config = kwargs.get("config") if kwargs else args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=_add_key("opening_statement", 2), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 3), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 4), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 5), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 6), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 7), + ), + ): + filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["opening_statement"] == 2 + assert filtered["suggested_questions_after_answer"] == 3 + assert filtered["speech_to_text"] == 4 + assert filtered["text_to_speech"] == 5 + assert filtered["retriever_resource"] == 6 + assert filtered["sensitive_word_avoidance"] == 7 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py new file mode 100644 index 0000000000..8faae3661d --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -0,0 +1,1255 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, ValidationError + +from constants import UUID_NIL +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestAdvancedChatAppGeneratorValidation: + def test_generate_requires_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query is required"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_generate_requires_string_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query must be a string"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}, "query": 123}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_single_iteration_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestAdvancedChatAppGeneratorInternals: + @staticmethod + def _build_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + def test_generate_loads_conversation_and_files(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + + conversation = SimpleNamespace(id="conversation-id") + built_files: list[object] = [] + build_files_called = {"called": False} + captured: dict[str, object] = {} + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.ConversationService.get_conversation", + lambda **kwargs: conversation, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda *args, **kwargs: {"enabled": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.file_factory.build_from_mappings", + lambda **kwargs: build_files_called.update({"called": True}) or built_files, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: kwargs["user_inputs"]) + + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) + or setattr(self, "user_id", user_id) + }, + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user-id" + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(features_dict={}), + user=user, + args={ + "query": "hello", + "inputs": {"k": "v"}, + "conversation_id": "conversation-id", + "files": [{"id": "f"}], + }, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert captured["conversation"] is conversation + assert captured["application_generate_entity"].files == built_files + assert build_files_called["called"] is True + + def test_resume_delegates_to_generate(self, monkeypatch): + generator = AdvancedChatAppGenerator() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=self._build_app_config(), + inputs={}, + query="hello", + files=[], + user_id="user", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + captured: dict[str, object] = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"resumed": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.resume( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conversation-id"), + message=SimpleNamespace(id="message-id"), + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=SimpleNamespace(), + pause_state_config=None, + ) + + assert result == {"resumed": True} + assert captured["graph_runtime_state"] is not None + + def test_single_iteration_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow): + prefill_calls.append(workflow) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_iteration_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(id="workflow-id"), + node_id="node-1", + user=SimpleNamespace(id="user-id"), + args={"inputs": {"foo": "bar"}}, + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1" + + def test_single_loop_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow): + prefill_calls.append(workflow) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_loop_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(id="workflow-id"), + node_id="node-2", + user=SimpleNamespace(id="user-id"), + args=SimpleNamespace(inputs={"foo": "bar"}), + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_loop_run.node_id == "node-2" + + def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-1") + db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) + captured: dict[str, object] = {} + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", lambda *args: (conversation, message)) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 2) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.PauseStatePersistenceLayer", + lambda **kwargs: "pause-layer", + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: captured.update(kwargs) or {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: {"response": response, "invoke_from": invoke_from}, + ) + + pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") + + response = generator._generate( + workflow=SimpleNamespace(features={"feature": True}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + conversation=None, + message=None, + stream=False, + pause_state_config=pause_state_config, + ) + + assert response["response"] == {"raw": True} + assert thread_data["started"] is True + assert "pause-layer" in thread_data["kwargs"]["graph_engine_layers"] + assert generator._dialogue_count == 3 + db_session.commit.assert_called_once() + db_session.refresh.assert_called_once_with(conversation) + db_session.close.assert_called_once() + assert captured["draft_var_saver_factory"] == "draft-factory" + + def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-2") + db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + init_records = MagicMock() + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", init_records) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 0) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: response, + ) + + response = generator._generate( + workflow=SimpleNamespace(features={}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert response == {"raw": True} + init_records.assert_not_called() + assert thread_data["started"] is True + db_session.commit.assert_not_called() + db_session.refresh.assert_not_called() + db_session.close.assert_called_once() + + def test_generate_worker_raises_when_workflow_not_found(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock(return_value=None) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="Workflow not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_raises_when_app_not_found_for_internal_call(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + None, + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="App not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_handles_stopped_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise GenerateTaskStoppedError() + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_not_called() + + def test_generate_worker_handles_validation_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _ValidationModel(BaseModel): + value: int + + try: + _ValidationModel(value="invalid") + except ValidationError as error: + validation_error = error + else: + raise AssertionError("validation error should be created") + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise validation_error + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def test_generate_worker_handles_value_and_unknown_errors(self, monkeypatch): + app_config = self._build_app_config() + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + def _make_runner(error: Exception): + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise error + + return _Runner + + for raised_error in [ValueError("bad input"), RuntimeError("unexpected")]: + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", + _make_runner(raised_error), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.dify_config", SimpleNamespace(DEBUG=True)) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + def test_handle_response_re_raises_value_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs): + _ = kwargs + + def process(self): + raise ValueError("other error") + + logger_exception = MagicMock() + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.logger.exception", logger_exception) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", _Pipeline) + + with pytest.raises(ValueError, match="other error"): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + logger_exception.assert_called_once() + + def test_refresh_model_returns_detached_model(self, monkeypatch): + source_model = SimpleNamespace(id="source-id") + detached_model = SimpleNamespace(id="source-id", detached=True) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, model_type, model_id): + _ = model_type + return detached_model if model_id == "source-id" else None + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) + + refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) + + assert refreshed is detached_model + + def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.SERVICE_API, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT)) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Runner: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def run(self): + from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + raise InvokeAuthorizationError("bad key") + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="end-user-id", session_id="session-id"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + assert queue_manager.publish_error.called + + def test_generate_debugger_enables_retrieve_source(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) + or setattr(self, "user_id", user_id) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user" + + result = generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello\x00", "inputs": {}}, + invoke_from=InvokeFrom.DEBUGGER, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert app_config.additional_features.show_retrieve_source is True + assert captured["application_generate_entity"].query == "hello" + + def test_generate_service_api_sets_parent_message_id(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) + or setattr(self, "user_id", user_id) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models.model import EndUser + + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + user.id = "end-user" + + generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello", "inputs": {}, "parent_message_id": "p1"}, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id="run-id", + streaming=False, + ) + + assert captured["application_generate_entity"].parent_message_id == UUID_NIL diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py new file mode 100644 index 0000000000..5b199e0c52 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -0,0 +1,96 @@ +from collections.abc import Generator + +from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, +) +from dify_graph.enums import WorkflowNodeExecutionStatus + + +class TestAdvancedChatGenerateResponseConverter: + def test_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + assert "usage" not in response["metadata"] + + def test_stream_simple_response_includes_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_start, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_finish, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py similarity index 100% rename from api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py rename to api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..b348ffc33b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -0,0 +1,600 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueMessageReplaceEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import NodeType +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import MessageStatus +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + message = SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ) + conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=conversation, + message=message, + user=user, + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestAdvancedChatGenerateTaskPipeline: + def test_ensure_workflow_initialized_raises(self): + pipeline = _make_pipeline() + + with pytest.raises(ValueError, match="workflow run not initialized"): + pipeline._ensure_workflow_initialized() + + def test_to_blocking_response_returns_message_end(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "done" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"}) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "done" + assert response.data.metadata == {"k": "v"} + + def test_handle_text_chunk_event_updates_state(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager = SimpleNamespace( + message_to_stream_response=lambda **kwargs: MessageEndStreamResponse( + task_id="task", id="message-id", metadata={} + ) + ) + + event = SimpleNamespace(text="hi", from_variable_selector=None) + + responses = list(pipeline._handle_text_chunk_event(event)) + + assert pipeline._task_state.answer == "hi" + assert responses + + def test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + pipeline._database_session = _fake_session + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace()) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_run_id == "run-id" + assert responses == ["started"] + + def test_message_end_to_stream_response_strips_annotation_reply(self): + pipeline = _make_pipeline() + pipeline._task_state.metadata.annotation_reply = AnnotationReply( + id="ann", + account=AnnotationReplyAccount(id="acc", name="acc"), + ) + + response = pipeline._message_end_to_stream_response() + + assert "annotation_reply" not in response.metadata + + def test_handle_output_moderation_chunk_publishes_stop(self): + pipeline = _make_pipeline() + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + pipeline._base_task_pipeline.queue_manager = SimpleNamespace( + publish=lambda event, pub_from: events.append(event) + ) + + result = pipeline._handle_output_moderation_chunk("ignored") + + assert result is True + assert pipeline._task_state.answer == "final" + assert any(isinstance(event, QueueTextChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_node_succeeded_event_records_files(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [ + {"type": "file", "transfer_method": "local"} + ] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + event = SimpleNamespace( + node_type=NodeType.ANSWER, + outputs={"k": "v"}, + node_execution_id="exec", + node_id="node", + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + assert pipeline._recorded_files + + def test_iteration_and_loop_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = ( + lambda **kwargs: "iter_start" + ) + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = ( + lambda **kwargs: "iter_done" + ) + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + + def test_workflow_finish_handlers(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"] + pipeline._persist_human_input_extra_content = lambda **kwargs: None + pipeline._save_message = lambda **kwargs: None + pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id") + + @contextmanager + def _fake_session(): + yield SimpleNamespace(scalar=lambda *args, **kwargs: None) + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) + assert len(succeeded_responses) == 2 + assert isinstance(succeeded_responses[0], MessageEndStreamResponse) + assert succeeded_responses[1] == "finish" + + partial_success_responses = list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ) + ) + assert len(partial_success_responses) == 2 + assert isinstance(partial_success_responses[0], MessageEndStreamResponse) + assert partial_success_responses[1] == "finish" + assert ( + list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0] + == "finish" + ) + assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [ + "pause" + ] + + def test_node_failure_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + failed_event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + exc_event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"] + assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"] + + def test_handle_text_chunk_event_tracks_streaming_metrics(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk") + + event = SimpleNamespace(text="hi", from_variable_selector=["a"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert responses == ["chunk"] + assert pipeline._task_state.is_streaming_response is True + assert pipeline._task_state.first_token_time is not None + assert pipeline._task_state.last_token_time is not None + assert pipeline._task_state.answer == "hi" + assert published == [queue_message] + + def test_handle_output_moderation_chunk_appends_token(self): + pipeline = _make_pipeline() + seen: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + seen.append(text) + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is False + assert seen == ["token"] + + def test_handle_retriever_and_annotation_events(self): + pipeline = _make_pipeline() + calls = {"retriever": 0, "annotation": 0} + + def _hit_retriever(event): + calls["retriever"] += 1 + + def _hit_annotation(event): + calls["annotation"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever + pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation + + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann") + + assert list(pipeline._handle_retriever_resources_event(retriever_event)) == [] + assert list(pipeline._handle_annotation_reply_event(annotation_event)) == [] + assert calls == {"retriever": 1, "annotation": 1} + + def test_handle_message_replace_event(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + + event = QueueMessageReplaceEvent( + text="new", + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + ) + + assert list(pipeline._handle_message_replace_event(event)) == ["replace"] + + def test_handle_human_input_events(self): + pipeline = _make_pipeline() + persisted: list[str] = [] + pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved") + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=NodeType.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert persisted == ["saved"] + + def test_save_message_strips_markdown_and_sets_usage(self): + pipeline = _make_pipeline() + pipeline._recorded_files = [ + { + "type": "image", + "transfer_method": "remote", + "remote_url": "http://example.com/file.png", + "related_id": "file-id", + } + ] + pipeline._task_state.answer = "![img](url) hello" + pipeline._task_state.is_streaming_response = True + pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1 + pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2 + + message = SimpleNamespace( + id="message-id", + status=MessageStatus.PAUSED, + answer="", + updated_at=None, + provider_response_latency=None, + message_tokens=None, + message_unit_price=None, + message_price_unit=None, + answer_tokens=None, + answer_unit_price=None, + answer_price_unit=None, + total_price=None, + currency=None, + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="end-user", + ) + + class _Session: + def scalar(self, *args, **kwargs): + return message + + def add_all(self, items): + self.items = items + + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state) + + assert message.status == MessageStatus.NORMAL + assert message.answer == "hello" + assert message.message_metadata + + def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._message_end_to_stream_response = lambda: "end" + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION))) + + assert responses == ["end"] + assert saved == ["saved"] + + def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline._message_end_to_stream_response = lambda: "end" + + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent())) + + assert responses == ["replace", "end"] + assert saved == ["saved"] + + def test_dispatch_event_handles_node_exception(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda *args, **kwargs: None + + event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["failed"] diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py new file mode 100644 index 0000000000..a871e8d93b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py @@ -0,0 +1,302 @@ +import uuid +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.agent_chat.app_config_manager import ( + AgentChatAppConfigManager, +) +from core.entities.agent_entities import PlanningStrategy + + +class TestAgentChatAppConfigManagerGetAppConfig: + def test_get_app_config_override_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"ignored": True} + + override_config = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override_config, + ) + + assert result.app_model_config_dict == override_config + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.variables == "variables" + assert result.external_data_variables == "external" + + def test_get_app_config_conversation_specific(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + conversation = mocker.MagicMock() + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=None, + ) + + assert result.app_model_config_dict == app_model_config.to_dict.return_value + assert result.app_model_config_from.value == "conversation-specific-config" + + def test_get_app_config_latest_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=None, + ) + + assert result.app_model_config_from.value == "app-latest-config" + + +class TestAgentChatAppConfigManagerConfigValidate: + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {}, + "user_input_form": {}, + "file_upload": {}, + "prompt_template": {}, + "agent_mode": {}, + "opening_statement": {}, + "suggested_questions_after_answer": {}, + "speech_to_text": {}, + "text_to_speech": {}, + "retriever_resource": {}, + "dataset": {}, + "moderation": {}, + "extra": "value", + } + + def return_with_key(key): + return config, [key] + + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("model"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("file_upload"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=lambda app_mode, cfg: return_with_key("prompt_template"), + ) + mocker.patch.object( + AgentChatAppConfigManager, + "validate_agent_mode_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("opening_statement"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("speech_to_text"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("text_to_speech"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("retriever_resource"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("moderation"), + ) + + filtered = AgentChatAppConfigManager.config_validate("tenant", config) + assert set(filtered.keys()) == { + "model", + "user_input_form", + "file_upload", + "prompt_template", + "agent_mode", + "opening_statement", + "suggested_questions_after_answer", + "speech_to_text", + "text_to_speech", + "retriever_resource", + "dataset", + "moderation", + } + assert "extra" not in filtered + + +class TestValidateAgentModeAndSetDefaults: + def test_defaults_when_missing(self): + config = {} + updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert "agent_mode" in updated + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + assert keys == ["agent_mode"] + + @pytest.mark.parametrize( + "agent_mode", + ["invalid", 123], + ) + def test_agent_mode_type_validation(self, agent_mode): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode}) + + def test_agent_mode_empty_list_defaults(self): + config = {"agent_mode": []} + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + + def test_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}}) + + def test_strategy_must_be_valid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}} + ) + + def test_tools_must_be_list(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}} + ) + + def test_old_tool_dataset_requires_id(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}} + ) + + def test_old_tool_dataset_id_must_be_uuid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}}, + ) + + def test_old_tool_dataset_id_not_exists(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=False, + ) + dataset_id = str(uuid.uuid4()) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}}, + ) + + def test_old_tool_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}}, + ) + + @pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"]) + def test_new_style_tool_requires_fields(self, missing_key): + tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"} + tool.pop(missing_key, None) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [tool]}} + ) + + def test_valid_old_and_new_style_tools(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=True, + ) + dataset_id = str(uuid.uuid4()) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER.value, + "tools": [ + {"dataset": {"id": dataset_id}}, + { + "provider_type": "builtin", + "provider_id": "p1", + "tool_name": "tool", + "tool_parameters": {}, + }, + ], + } + } + + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False + assert updated["agent_mode"]["tools"][1]["enabled"] is False diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py new file mode 100644 index 0000000000..53f26d1592 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -0,0 +1,296 @@ +import contextlib + +import pytest +from pydantic import ValidationError + +from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +class DummyAccount: + def __init__(self, user_id): + self.id = user_id + self.session_id = f"session-{user_id}" + + +@pytest.fixture +def generator(mocker): + gen = AgentChatAppGenerator() + mocker.patch( + "core.app.apps.agent_chat.app_generator.current_app", + new=mocker.MagicMock(_get_current_object=mocker.MagicMock()), + ) + mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx") + return gen + + +class TestAgentChatAppGeneratorGenerate: + def test_generate_rejects_blocking_mode(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False) + + def test_generate_requires_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock()) + + def test_generate_rejects_non_string_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": 123, "inputs": {}}, + invoke_from=mocker.MagicMock(), + ) + + def test_generate_override_requires_debugger(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}}, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_success_with_debugger_override(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + invoke_from = InvokeFrom.DEBUGGER + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate", + return_value={"validated": True}, + ) + app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=app_config, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ConversationService.get_conversation", + return_value=mocker.MagicMock(id="conv"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + + queue_manager = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=queue_manager, + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = { + "query": "hello", + "inputs": {"name": "world"}, + "conversation_id": "conv", + "model_config": {"model": {"provider": "p"}}, + "files": [{"id": "f1"}], + } + + result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True) + + assert result == {"result": "ok"} + thread_obj.start.assert_called_once() + + def test_generate_without_file_config(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=None, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=mocker.MagicMock(), + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = {"query": "hello", "inputs": {"name": "world"}} + + result = generator.generate( + app_model=app_model, + user=user, + args=args, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == {"result": "ok"} + + +class TestAgentChatAppGeneratorWorker: + @pytest.fixture(autouse=True) + def patch_context(self, mocker): + @contextlib.contextmanager + def ctx_manager(*args, **kwargs): + yield + + mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager) + + def test_generate_worker_handles_generate_task_stopped(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = GenerateTaskStoppedError() + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + queue_manager.publish_error.assert_not_called() + + @pytest.mark.parametrize( + "error", + [ + InvokeAuthorizationError("bad"), + ValidationError.from_exception_data("TestModel", []), + ValueError("bad"), + Exception("bad"), + ], + ) + def test_generate_worker_publishes_errors(self, generator, mocker, error): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = error + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + assert queue_manager.publish_error.called + + def test_generate_worker_logs_value_error_when_debug(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = ValueError("bad") + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True)) + logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py new file mode 100644 index 0000000000..5603115b30 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -0,0 +1,413 @@ +import pytest + +from core.agent.entities import AgentEntity +from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey + + +@pytest.fixture +def runner(): + return AgentChatAppRunner() + + +class TestAgentChatAppRunnerRun: + def test_run_app_not_found(self, runner, mocker): + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock()) + generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_moderation_error_direct_output(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad")) + mocker.patch.object(runner, "direct_output") + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + runner.direct_output.assert_called_once() + + def test_run_annotation_reply_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + user_id="user", + invoke_from=mocker.MagicMock(), + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + annotation = mocker.MagicMock(id="anno", content="answer") + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation) + mocker.patch.object(runner, "direct_output") + + queue_manager = mocker.MagicMock() + runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock()) + + queue_manager.publish.assert_called_once() + runner.direct_output.assert_called_once() + + def test_run_hosting_moderation_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=True) + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_model_schema_missing(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = None + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + @pytest.mark.parametrize( + ("mode", "expected_runner"), + [ + (LLMMode.CHAT, "CotChatAgentRunner"), + (LLMMode.COMPLETION, "CotCompletionAgentRunner"), + ], + ) + def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: mode} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + runner_instance.run.assert_called_once() + runner._handle_invoke_result.assert_called_once() + + def test_run_invalid_llm_mode_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + def test_run_function_calling_strategy_selected_by_features(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [ModelFeature.TOOL_CALL] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING + runner_instance.run.assert_called_once() + + def test_run_conversation_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_message_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, mocker.MagicMock(id="conv"), None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_invalid_agent_strategy_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m") + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py new file mode 100644 index 0000000000..02a1e04c98 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py @@ -0,0 +1,162 @@ +from collections.abc import Generator + +from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestAgentChatAppGenerateResponseConverterBlocking: + def test_convert_blocking_full_response(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={"a": 1}, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["answer"] == "answer" + assert result["metadata"] == {"a": 1} + + def test_convert_blocking_simple_response_with_dict_metadata(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={ + "retriever_resources": [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + + def test_convert_blocking_simple_response_with_non_dict_metadata(self): + blocking = ChatbotAppBlockingResponse.model_construct( + task_id="task", + data=ChatbotAppBlockingResponse.Data.model_construct( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + +class TestAgentChatAppGenerateResponseConverterStream: + def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]: + def _gen(): + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=1, + stream_response=PingStreamResponse(task_id="t"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=2, + stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=3, + stream_response=MessageEndStreamResponse( + task_id="t", + id="m1", + metadata={ + "retriever_resources": [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + "summary": "summary", + "extra": "ignored", + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + ), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=4, + stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")), + ) + + return _gen() + + def test_convert_stream_full_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream())) + assert items[0] == "ping" + assert items[1]["event"] == "message" + assert "answer" in items[1] + assert items[2]["event"] == "message_end" + assert items[3]["event"] == "error" + + def test_convert_stream_simple_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream())) + assert items[0] == "ping" + # Assert the message event structure and content at items[1] + assert items[1]["event"] == "message" + assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"] + assert items[2]["event"] == "message_end" + assert "metadata" in items[2] + metadata = items[2]["metadata"] + assert "annotation_reply" not in metadata + assert "usage" not in metadata + assert metadata["retriever_resources"] == [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + "summary": "summary", + } + ] + assert items[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/chat/__init__.py b/api/tests/unit_tests/core/app/apps/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py new file mode 100644 index 0000000000..271d007be6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py @@ -0,0 +1,113 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from models.model import AppMode + + +class TestChatAppConfigManager: + def test_get_app_config_uses_override_dict(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"}) + override = {"model": "override"} + + model_entity = ModelConfigEntity(provider="p", model="m") + prompt_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ) + + with ( + patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])), + ): + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override, + ) + + assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert app_config.app_model_config_dict == override + assert app_config.app_mode == AppMode.CHAT + + def test_config_validate_filters_related_keys(self): + config = {"extra": 1} + + def _add_key(key, value): + def _inner(*args, **kwargs): + config = args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=_add_key("model", 1), + ), + patch( + "core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=_add_key("inputs", 2), + ), + patch( + "core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 3), + ), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=_add_key("prompt", 4), + ), + patch( + "core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=_add_key("dataset", 5), + ), + patch( + "core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=_add_key("opening_statement", 6), + ), + patch( + "core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 7), + ), + patch( + "core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 8), + ), + patch( + "core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 9), + ), + patch( + "core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 10), + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 11), + ), + ): + filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config) + + assert filtered["model"] == 1 + assert filtered["inputs"] == 2 + assert filtered["file_upload"] == 3 + assert filtered["prompt"] == 4 + assert filtered["dataset"] == 5 + assert filtered["opening_statement"] == 6 + assert filtered["suggested_questions_after_answer"] == 7 + assert filtered["speech_to_text"] == 8 + assert filtered["text_to_speech"] == 9 + assert filtered["retriever_resource"] == 10 + assert filtered["sensitive_word_avoidance"] == 11 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py new file mode 100644 index 0000000000..3cdffbb4cd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -0,0 +1,280 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.chat.app_generator import ChatAppGenerator +from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAnnotationReplyEvent +from core.moderation.base import ModerationError +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from models.model import AppMode + + +class DummyGenerateEntity: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class DummyQueueManager: + def __init__(self, *args, **kwargs): + self.published = [] + + def publish_error(self, error, pub_from): + self.published.append((error, pub_from)) + + def publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestChatAppGenerator: + def test_generate_requires_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_rejects_non_string_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"query": 1, "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_debugger_overrides_model_config(self): + generator = ChatAppGenerator() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + user = SimpleNamespace(id="user-1", session_id="session-1") + args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}} + + with ( + patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None), + patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}), + patch( + "core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config", + return_value=SimpleNamespace( + variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT + ), + ), + patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]), + patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity), + patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager), + patch( + "core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True} + ), + patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})), + patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}), + patch.object( + ChatAppGenerator, + "_init_generate_records", + return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")), + ), + patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}), + patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f), + patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread, + ): + mock_thread.return_value.start.return_value = None + result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False) + + assert result == {"ok": True} + + def test_generate_rejects_model_config_override_for_non_debugger(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + with ( + patch.object( + ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {}) + ), + ): + generator.generate( + app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value), + user=SimpleNamespace(id="u1", session_id="s1"), + args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_worker_handles_exceptions(self): + generator = ChatAppGenerator() + queue_manager = DummyQueueManager() + entity = DummyGenerateEntity(task_id="t1", user_id="u1") + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + assert queue_manager.published + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + +class TestChatAppRunner: + def test_run_raises_when_app_missing(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[] + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None): + with pytest.raises(ValueError): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + def test_run_moderation_error_direct_output(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + mock_direct.assert_called_once() + + def test_run_annotation_reply_short_circuits(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + annotation = SimpleNamespace(id="ann-1", content="answer") + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + queue_manager = DummyQueueManager() + runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1")) + + assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published) + mock_direct.assert_called_once() + + def test_run_returns_when_hosting_moderation_blocks(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None), + patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True), + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) diff --git a/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py new file mode 100644 index 0000000000..01272ba052 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py @@ -0,0 +1,65 @@ +from collections.abc import Generator + +from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestChatAppGenerateResponseConverter: + def test_convert_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + + response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "usage" not in response["metadata"] + + def test_convert_stream_responses(self): + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream())) + assert full[0] == "ping" + assert full[1]["event"] == "message" + assert full[2]["event"] == "error" + + simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert simple[0] == "ping" + assert simple[-1]["event"] == "message_end" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py new file mode 100644 index 0000000000..51f33bac35 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -0,0 +1,162 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.completion.app_runner as module +from core.app.apps.completion.app_runner import CompletionAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + +@pytest.fixture +def runner(): + return CompletionAppRunner() + + +def _build_app_config(dataset=None, external_tools=None, additional_features=None): + app_config = MagicMock() + app_config.app_id = "app1" + app_config.tenant_id = "tenant" + app_config.prompt_template = MagicMock() + app_config.dataset = dataset + app_config.external_data_variables = external_tools or [] + app_config.additional_features = additional_features + app_config.app_model_config_dict = {"file_upload": {"enabled": True}} + return app_config + + +def _build_generate_entity(app_config, file_upload_config=None): + model_conf = MagicMock( + provider_model_bundle="bundle", + model="model", + parameters={"max_tokens": 10}, + stop=["stop"], + ) + return SimpleNamespace( + app_config=app_config, + model_conf=model_conf, + inputs={"qvar": "query_from_input"}, + query="original_query", + files=[], + file_upload_config=file_upload_config, + stream=True, + user_id="user", + invoke_from=MagicMock(), + ) + + +class TestCompletionAppRunner: + def test_run_app_not_found(self, runner, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + with pytest.raises(ValueError): + runner.run(app_generate_entity, MagicMock(), MagicMock()) + + def test_run_moderation_error_outputs_direct(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked")) + runner.direct_output = MagicMock() + runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner.direct_output.assert_called_once() + runner._handle_invoke_result.assert_not_called() + + def test_run_hosting_moderation_stops(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner._handle_invoke_result.assert_not_called() + + def test_run_dataset_and_external_tools_flow(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + session.close = MagicMock() + mocker.patch.object(module.db, "session", session) + + retrieve_config = MagicMock(query_variable="qvar") + dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config) + additional_features = MagicMock(show_retrieve_source=True) + app_config = _build_app_config( + dataset=dataset_config, + external_tools=["tool"], + additional_features=additional_features, + ) + + file_upload_config = MagicMock() + file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH + + app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config) + + runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])]) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs) + runner.check_hosting_moderation = MagicMock(return_value=False) + runner.recalc_llm_max_tokens = MagicMock() + runner._handle_invoke_result = MagicMock() + + dataset_retrieval = MagicMock() + dataset_retrieval.retrieve.return_value = ("ctx", ["file1"]) + mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval) + + model_instance = MagicMock() + model_instance.invoke_llm.return_value = "invoke_result" + mocker.patch.object(module, "ModelInstance", return_value=model_instance) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant")) + + dataset_retrieval.retrieve.assert_called_once() + assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input" + runner._handle_invoke_result.assert_called_once() + + def test_run_uses_low_image_detail_default(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config, file_upload_config=None) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + assert ( + runner.organize_prompt_messages.call_args.kwargs["image_detail_config"] + == ImagePromptMessageContent.DETAIL.LOW + ) diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py new file mode 100644 index 0000000000..024bd8f302 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.completion.app_config_manager as module +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from models.model import AppMode + + +class TestCompletionAppConfigManager: + def test_get_app_config_with_override(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + override_config = {"model": {"provider": "override"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_config, + ) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.app_model_config_dict == override_config + assert result.variables == ["v1"] + assert result.external_data_variables == ["ext1"] + assert result.app_mode == AppMode.COMPLETION + + def test_get_app_config_without_override_uses_model_config(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], [])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + assert result.app_model_config_dict == {"model": {"provider": "x"}} + + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {"provider": "x"}, + "variables": ["v"], + "file_upload": {"enabled": True}, + "prompt": {"template": "t"}, + "dataset": {"enabled": True}, + "tts": {"enabled": True}, + "more_like_this": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.ModelConfigManager, + "validate_and_set_defaults", + return_value=(config, ["model"]), + ) + mocker.patch.object( + module.BasicVariablesConfigManager, + "validate_and_set_defaults", + return_value=(config, ["variables"]), + ) + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.PromptTemplateConfigManager, + "validate_and_set_defaults", + return_value=(config, ["prompt"]), + ) + mocker.patch.object( + module.DatasetConfigManager, + "validate_and_set_defaults", + return_value=(config, ["dataset"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.MoreLikeThisConfigManager, + "validate_and_set_defaults", + return_value=(config, ["more_like_this"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = CompletionAppConfigManager.config_validate("tenant", config) + + assert "extra" not in filtered + assert set(filtered.keys()) == { + "model", + "variables", + "file_upload", + "prompt", + "dataset", + "tts", + "more_like_this", + "moderation", + } diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py new file mode 100644 index 0000000000..2714757353 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -0,0 +1,321 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +import core.app.apps.completion.app_generator as module +from core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +@pytest.fixture +def generator(mocker): + gen = CompletionAppGenerator() + + mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn) + + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app))) + + thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=thread) + + mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + return gen + + +def _build_app_model(): + return MagicMock(tenant_id="tenant", id="app1", mode="completion") + + +def _build_user(): + return MagicMock(id="user", session_id="session") + + +def _build_app_model_config(): + config = MagicMock(id="cfg") + config.to_dict.return_value = {"model": {"provider": "x"}} + return config + + +class TestCompletionAppGenerator: + def test_generate_invalid_query_type(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": 123, "inputs": {}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + def test_generate_override_not_debugger(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": {}}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + def test_generate_success_no_file_config(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.file_factory, "build_from_mappings") + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_not_called() + + def test_generate_success_with_files(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_called_once() + + def test_generate_override_model_config_debugger(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + override_config = {"model": {"provider": "override"}} + mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": override_config}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert get_app_config.call_args.kwargs["override_config_dict"] == override_config + + def test_generate_more_like_this_message_not_found(self, generator, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MessageNotExistsError): + generator.generate_more_like_this( + app_model=_build_app_model(), + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_disabled(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False}) + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_app_model_config_missing(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = None + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_message_config_none(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock(app_model_config=None) + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_success(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock() + message.message_files = [{"id": "f"}] + message.inputs = {"a": 1} + message.query = "q" + + app_model_config = MagicMock() + app_model_config.to_dict.return_value = { + "model": {"completion_params": {"temperature": 0.1}}, + "file_upload": {"enabled": True}, + } + message.app_model_config = app_model_config + + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + stream=True, + ) + + assert result == "converted" + override_dict = get_app_config.call_args.kwargs["override_config_dict"] + assert override_dict["model"]["completion_params"]["temperature"] == 0.9 + + @pytest.mark.parametrize( + ("error", "should_publish"), + [ + (GenerateTaskStoppedError(), False), + (InvokeAuthorizationError("bad"), True), + ( + ValidationError.from_exception_data( + "Model", + [{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}], + ), + True, + ), + (ValueError("bad"), True), + (RuntimeError("boom"), True), + ], + ) + def test_generate_worker_error_handling(self, generator, mocker, error, should_publish): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(generator, "_get_message", return_value=MagicMock()) + + runner_instance = MagicMock() + runner_instance.run.side_effect = error + mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=MagicMock(), + queue_manager=queue_manager, + message_id="msg", + ) + + assert queue_manager.publish_error.called is should_publish diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py new file mode 100644 index 0000000000..cf473dfbeb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py @@ -0,0 +1,153 @@ +from collections.abc import Generator + +from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + CompletionAppBlockingResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestCompletionAppGenerateResponseConverter: + def test_convert_blocking_full_response(self): + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata={"k": "v"}, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["task_id"] == "task" + assert result["message_id"] == "msg" + assert result["answer"] == "answer" + assert result["metadata"] == {"k": "v"} + + def test_convert_blocking_simple_response_metadata_simplified(self): + metadata = { + "retriever_resources": [ + { + "segment_id": "s", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "c", + "summary": "sum", + "extra": "x", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + } + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata=metadata, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s" + assert "extra" not in result["metadata"]["retriever_resources"][0] + + def test_convert_blocking_simple_response_metadata_not_dict(self): + data = CompletionAppBlockingResponse.Data.model_construct( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ) + blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + def test_convert_stream_full_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + assert result[2]["event"] == "message" + + def test_convert_stream_simple_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=MessageEndStreamResponse( + task_id="t", + id="end", + metadata={ + "retriever_resources": [ + { + "segment_id": "s", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "c", + "summary": "sum", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + }, + ), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "message_end" + assert "annotation_reply" not in result[1]["metadata"] + assert "usage" not in result[1]["metadata"] + assert result[2]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py new file mode 100644 index 0000000000..5d4c9bcde0 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py @@ -0,0 +1,55 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.pipeline.pipeline_config_manager as module +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from models.model import AppMode + + +def test_get_pipeline_config(mocker): + pipeline = MagicMock(tenant_id="tenant", id="pipe1") + workflow = MagicMock(id="wf1") + + mocker.patch.object( + module.WorkflowVariablesConfigManager, + "convert_rag_pipeline_variable", + return_value=["var1"], + ) + mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start") + + assert result.tenant_id == "tenant" + assert result.app_id == "pipe1" + assert result.workflow_id == "wf1" + assert result.app_mode == AppMode.RAG_PIPELINE + assert result.rag_pipeline_variables == ["var1"] + + +def test_config_validate_filters_related_keys(mocker): + config = { + "file_upload": {"enabled": True}, + "tts": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = PipelineConfigManager.config_validate("tenant", config) + + assert set(filtered.keys()) == {"file_upload", "tts", "moderation"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py new file mode 100644 index 0000000000..94ed8166b9 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -0,0 +1,111 @@ +from collections.abc import Generator + +from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +def test_convert_blocking_full_and_simple_response(): + blocking = WorkflowAppBlockingResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowAppBlockingResponse.Data( + id="id", + workflow_id="wf", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"k": "v"}, + error=None, + elapsed_time=0.1, + total_tokens=10, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert full == simple + assert full["workflow_run_id"] == "run" + assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED + + +def test_convert_stream_full_response(): + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + workflow_run_id="run", + ) + yield WorkflowAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + workflow_run_id="run", + ) + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + + +def test_convert_stream_simple_response_node_ignore_details(): + node_start = NodeStartStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + process_data=None, + process_data_truncated=False, + outputs={"b": 2}, + outputs_truncated=False, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=0.1, + execution_metadata=None, + created_at=1, + finished_at=2, + files=[], + ), + ) + + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run") + yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run") + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0]["event"] == "node_started" + assert result[0]["data"]["inputs"] is None + assert result[1]["event"] == "node_finished" + assert result[1]["data"]["inputs"] is None diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py new file mode 100644 index 0000000000..06face41fe --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -0,0 +1,699 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock + +import pytest + +import core.app.apps.pipeline.pipeline_generator as module +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceProviderType + + +class FakeRagPipelineGenerateEntity(SimpleNamespace): + class SingleIterationRunEntity(SimpleNamespace): + pass + + class SingleLoopRunEntity(SimpleNamespace): + pass + + def model_dump(self): + return dict(self.__dict__) + + +@pytest.fixture +def generator(mocker): + gen = module.PipelineGenerator() + + mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity) + mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs) + mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock())) + mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock())) + + return gen + + +def _build_pipeline_dataset(): + return SimpleNamespace( + id="ds", + name="dataset", + description="desc", + chunk_structure="chunk", + built_in_field_enabled=True, + tenant_id="tenant", + ) + + +def _build_pipeline(): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + pipeline.retrieve_dataset.return_value = _build_pipeline_dataset() + return pipeline + + +def _build_workflow(): + return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant") + + +def _build_user(): + return MagicMock(id="user", name="User", session_id="session") + + +def _build_args(): + return { + "inputs": {"k": "v"}, + "start_node_id": "start", + "datasource_type": DatasourceProviderType.LOCAL_FILE.value, + "datasource_info_list": [{"name": "file"}], + } + + +def _patch_session(mocker, session): + mocker.patch.object(module, "Session", return_value=session) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + +def _dummy_preserve(*args, **kwargs): + return contextlib.nullcontext() + + +class DummySession: + def __init__(self): + self.scalar = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +def test_generate_dataset_missing(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.generate( + pipeline=pipeline, + workflow=_build_workflow(), + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + +def test_generate_debugger_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + datasource_info_list = [{"name": "file1"}, {"name": "file2"}] + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=datasource_info_list, + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1) + + document1 = SimpleNamespace( + id="doc1", + position=1, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file1", + indexing_status="", + error=None, + enabled=True, + ) + document2 = SimpleNamespace( + id="doc2", + position=2, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file2", + indexing_status="", + error=None, + enabled=True, + ) + mocker.patch.object(generator, "_build_document", side_effect=[document1, document2]) + + mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock()) + + db_session = MagicMock() + mocker.patch.object(module.db, "session", db_session) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + task_proxy = MagicMock() + mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=False, + ) + + assert result["batch"] + assert len(result["documents"]) == 2 + task_proxy.delay.assert_called_once() + + +def test_generate_is_retry_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=True, + is_retry=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_worker_handles_errors(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + runner_instance.run.side_effect = ValueError("bad") + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + queue_manager.publish_error.assert_called_once() + + +def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session" + + +def test_generate_raises_when_workflow_not_found(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + +def test_generate_success_returns_converted(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", session) + + queue_manager = MagicMock() + mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager) + + worker_thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=worker_thread) + + mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock()) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + assert result == "converted" + + +def test_single_iteration_generate_validates_inputs(generator, mocker): + with pytest.raises(ValueError): + generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {}) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + _build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None} + ) + + +def test_single_iteration_generate_dataset_required(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + ) + + +def test_single_iteration_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_single_loop_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_loop_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + app_entity = FakeRagPipelineGenerateEntity(task_id="t") + + task_pipeline = MagicMock() + task_pipeline.process.side_effect = ValueError("I/O operation on closed file.") + mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=app_entity, + workflow=workflow, + queue_manager=MagicMock(), + user=_build_user(), + draft_var_saver_factory=MagicMock(), + stream=False, + ) + + +def test_build_document_sets_metadata_for_builtin_fields(generator, mocker): + class DummyDocument(SimpleNamespace): + pass + + mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs)) + + document = generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=True, + datasource_type=DatasourceProviderType.LOCAL_FILE, + datasource_info={"name": "file"}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + assert document.name == "file" + assert document.doc_metadata + + +def test_build_document_invalid_datasource_type(generator): + with pytest.raises(ValueError): + generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=False, + datasource_type="invalid", + datasource_info={}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + +def test_format_datasource_info_list_non_online_drive(generator): + result = generator._format_datasource_info_list( + DatasourceProviderType.LOCAL_FILE, + [{"name": "file"}], + _build_pipeline(), + _build_workflow(), + "start", + _build_user(), + ) + + assert result == [{"name": "file"}] + + +def test_format_datasource_info_list_missing_node_data(generator): + workflow = MagicMock(graph_dict={"nodes": []}) + + with pytest.raises(ValueError): + generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + +def test_format_datasource_info_list_online_drive_folder(generator, mocker): + workflow = MagicMock( + graph_dict={ + "nodes": [ + { + "id": "start", + "data": { + "plugin_id": "p", + "provider_name": "provider", + "datasource_name": "drive", + "credential_id": "cred", + }, + } + ] + } + ) + + runtime = MagicMock() + runtime.runtime = SimpleNamespace(credentials=None) + runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE + + mocker.patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=runtime, + ) + mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"}) + + mocker.patch.object( + generator, + "_get_files_in_folder", + side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}), + ) + + result = generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + assert result == [{"id": "f"}] + + +def test_get_files_in_folder_recurses_and_collects(generator): + class File: + def __init__(self, id, name, type): + self.id = id + self.name = name + self.type = type + + class FilesPage: + def __init__(self, files, is_truncated=False, next_page_parameters=None): + self.files = files + self.is_truncated = is_truncated + self.next_page_parameters = next_page_parameters + + class Result: + def __init__(self, result): + self.result = result + + class Runtime: + def __init__(self): + self.calls = [] + + def datasource_provider_type(self): + return DatasourceProviderType.ONLINE_DRIVE + + def online_drive_browse_files(self, user_id, request, provider_type): + self.calls.append(request.next_page_parameters) + if request.prefix == "fd": + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + if request.next_page_parameters is None: + return iter( + [ + Result( + [FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})] + ) + ] + ) + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + + runtime = Runtime() + all_files = [] + + generator._get_files_in_folder( + datasource_runtime=runtime, + prefix="root", + bucket="b", + user_id="user", + all_files=all_files, + datasource_info={}, + ) + + assert {f["id"] for f in all_files} == {"f1", "f2"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py new file mode 100644 index 0000000000..72f7552bd1 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -0,0 +1,57 @@ +import pytest + +import core.app.apps.pipeline.pipeline_queue_manager as module +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult + + +def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER) + + manager.stop_listen.assert_called_once() + + +def test_publish_stop_events_trigger_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + for event in [ + QueueErrorEvent(error=ValueError("bad")), + QueueMessageEndEvent(llm_result=LLMResult.model_construct()), + QueueWorkflowSucceededEvent(), + QueueWorkflowFailedEvent(error="failed", exceptions_count=1), + QueueWorkflowPartialSuccessEvent(exceptions_count=1), + ]: + manager.stop_listen.reset_mock() + manager._publish(event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_called_once() + + +def test_publish_non_stop_event_no_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent) + manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py new file mode 100644 index 0000000000..eec95b7f39 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -0,0 +1,297 @@ +""" +Unit tests for PipelineRunner behavior. +Asserts correct event handling, error propagation, and user invocation logic. +Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies. +Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities. +""" + +"""Unit tests for PipelineRunner behavior. + +This module validates core control-flow outcomes for +``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph +initialization guards, invoke-source to user-source resolution, and failed-run +event handling. Invariants asserted here include strict graph-config +validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing +error paths driven by ``GraphRunFailedEvent`` through mocked collaborators. +Primary collaborators include ``PipelineRunner``, +``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``, +``UserFrom``, and patched DB/runtime dependencies used by the runner. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.pipeline.pipeline_runner as module +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.graph_events import GraphRunFailedEvent + + +def _build_app_generate_entity() -> SimpleNamespace: + app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant") + return SimpleNamespace( + app_config=app_config, + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + trace_manager=MagicMock(), + inputs={"input1": "v1"}, + files=[], + workflow_execution_id="run", + document_id="doc", + original_document_id=None, + batch="batch", + dataset_id="ds", + datasource_type="local_file", + datasource_info={"name": "file"}, + start_node_id="start", + call_depth=0, + single_iteration_run=None, + single_loop_run=None, + ) + + +@pytest.fixture +def runner(): + app_generate_entity = _build_app_generate_entity() + queue_manager = MagicMock() + variable_loader = MagicMock() + workflow = MagicMock() + workflow_execution_repository = MagicMock() + workflow_node_execution_repository = MagicMock() + + return PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=queue_manager, + variable_loader=variable_loader, + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + +def test_get_app_id(runner): + assert runner._get_app_id() == "pipe" + + +def test_get_workflow_returns_workflow(mocker, runner): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + workflow = MagicMock(id="wf") + + query = MagicMock() + query.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query))) + + result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") + + assert result == workflow + + +def test_init_rag_pipeline_graph_invalid_config(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={}) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": "bad", "edges": []} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": [], "edges": "bad"} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def test_init_rag_pipeline_graph_not_found(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []}) + mocker.patch.object(module.Graph, "init", return_value=None) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def test_update_document_status_on_failure(mocker, runner): + document = MagicMock() + + query = MagicMock() + query.where.return_value.first.return_value = document + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + event = GraphRunFailedEvent(error="boom") + + runner._update_document_status(event, document_id="doc", dataset_id="ds") + + assert document.indexing_status == "error" + assert document.error == "boom" + session.commit.assert_called_once() + + +def test_run_pipeline_not_found(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.invoke_from = InvokeFrom.WEB_APP + app_generate_entity.single_iteration_run = None + app_generate_entity.single_loop_run = None + + query = MagicMock() + query.where.return_value.first.return_value = None + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_workflow_not_initialized(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + session = MagicMock() + session.query.return_value = query_pipeline + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + runner.get_workflow = MagicMock(return_value=None) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_single_iteration_path(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.single_iteration_run = MagicMock() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock( + return_value=MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={}, + type="rag-pipeline", + version="v1", + ) + ) + runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state")) + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [MagicMock()] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._prepare_single_node_execution.assert_called_once() + runner._handle_event.assert_called() + + +def test_run_normal_path_builds_graph(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + workflow = MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={"nodes": [], "edges": []}, + environment_variables=[], + rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}], + type="rag-pipeline", + version="v1", + ) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock(return_value=workflow) + runner._init_rag_pipeline_graph = MagicMock(return_value="graph") + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + mocker.patch.object( + module.RAGPipelineVariable, + "model_validate", + return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), + ) + mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._init_rag_pipeline_graph.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 43a97ae098..8f1baaa1e4 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest from core.app.apps.base_app_generator import BaseAppGenerator @@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default(): ) assert result is None + + +class TestBaseAppGeneratorExtras: + def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch): + base_app_generator = BaseAppGenerator() + + variables = [ + VariableEntity( + variable="file", + label="file", + type=VariableEntityType.FILE, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( + variable="file_list", + label="file_list", + type=VariableEntityType.FILE_LIST, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( + variable="json", + label="json", + type=VariableEntityType.JSON_OBJECT, + required=False, + ), + ] + + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mapping", + lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + ) + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mappings", + lambda mappings, tenant_id, config: ["file-1", "file-2"], + ) + + user_inputs = { + "file": {"id": "file-id"}, + "file_list": [{"id": "file-1"}, {"id": "file-2"}], + "json": {"key": "value"}, + } + + prepared = base_app_generator._prepare_user_inputs( + user_inputs=user_inputs, + variables=variables, + tenant_id="tenant-id", + ) + + assert prepared["file"] == "file-object" + assert prepared["file_list"] == ["file-1", "file-2"] + assert prepared["json"] == {"key": "value"} + + def test_prepare_user_inputs_rejects_invalid_dict_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": {"unexpected": "dict"}}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_prepare_user_inputs_rejects_invalid_list_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": [{"unexpected": "dict"}]}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_convert_to_event_stream(self): + base_app_generator = BaseAppGenerator() + + assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True} + + def _gen(): + yield {"delta": "hi"} + yield "ping" + + converted = list(base_app_generator.convert_to_event_stream(_gen())) + + assert converted[0].startswith("data: ") + assert "\n\n" in converted[0] + assert converted[1] == "event: ping\n\n" + + def test_get_draft_var_saver_factory_debugger(self): + from core.app.entities.app_invoke_entities import InvokeFrom + from dify_graph.enums import NodeType + from models import Account + + base_app_generator = BaseAppGenerator() + account = Account(name="Tester", email="tester@example.com") + account.id = "account-id" + account.tenant_id = "tenant-id" + + factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) + saver = factory( + session=MagicMock(), + app_id="app-id", + node_id="node-id", + node_type=NodeType.START, + node_execution_id="node-exec-id", + ) + + assert saver is not None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py new file mode 100644 index 0000000000..c6dc20ffc6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueErrorEvent + + +class DummyQueueManager(AppQueueManager): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.published = [] + + def _publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestBaseAppQueueManager: + def test_init_requires_user_id(self): + with pytest.raises(ValueError): + DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API) + + def test_publish_error_records_event(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE) + + assert isinstance(manager.published[0][0], QueueErrorEvent) + + def test_set_stop_flag_checks_user(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.get.return_value = b"end-user-u1" + AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1") + + mock_redis.setex.assert_called_once() + + def test_set_stop_flag_no_user_check(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + AppQueueManager.set_stop_flag_no_user_check(task_id="t1") + + mock_redis.setex.assert_called_once() + + def test_is_stopped_reads_cache(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = b"1" + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + assert manager._is_stopped() is True + + def test_check_for_sqlalchemy_models_raises(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + bad = SimpleNamespace(_sa_instance_state=True) + with pytest.raises(TypeError): + manager._check_for_sqlalchemy_models(bad) diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py new file mode 100644 index 0000000000..aabeb54553 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from models.model import AppMode + + +class _DummyParameterRule: + def __init__(self, name: str, use_template: str | None = None) -> None: + self.name = name + self.use_template = use_template + + +class _QueueRecorder: + def __init__(self) -> None: + self.events: list[object] = [] + + def publish(self, event, pub_from): + _ = pub_from + self.events.append(event) + + +class TestAppRunner: + def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80), + ) + + runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")]) + + assert model_config.parameters["max_tokens"] == 20 + + def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10), + ) + + assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1 + + def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True) + + monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None) + + runner.direct_output( + queue_manager=queue, + app_generate_entity=app_generate_entity, + prompt_messages=[], + text="hi", + stream=True, + ) + + assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_direct_publishes_end_event(self): + runner = AppRunner() + queue = _QueueRecorder() + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + runner._handle_invoke_result( + invoke_result=llm_result, + queue_manager=queue, + stream=False, + ) + + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_invalid_type_raises(self): + runner = AppRunner() + queue = _QueueRecorder() + + with pytest.raises(NotImplementedError): + runner._handle_invoke_result( + invoke_result=["unexpected"], + queue_manager=queue, + stream=True, + ) + + def test_organize_prompt_messages_simple_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=["STOP"]) + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hello", + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.SimplePromptTransform.get_prompt", + lambda self, **kwargs: (["simple-message"], ["simple-stop"]), + ) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["simple-message"] + assert stop == ["simple-stop"] + + def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="completion", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="answer", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"), + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-completion-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-completion-message"] + assert stop == [""] + memory_config = captured["memory_config"] + assert memory_config.role_prefix.user == "U" + assert memory_config.role_prefix.assistant == "A" + + def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-chat-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-chat-message"] + assert stop == [""] + assert len(captured["prompt_template"]) == 2 + + def test_organize_prompt_messages_advanced_missing_templates_raise(self): + runner = AppRunner() + + with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="completion", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="chat", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + warning_logger = MagicMock() + monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger) + + image_content = ImagePromptMessageContent( + url="https://example.com/image.png", format="png", mime_type="image/png" + ) + + def _stream(): + yield LLMResultChunk( + model="stream-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage.model_construct( + content=[ + "a", + TextPromptMessageContent(data="b"), + SimpleNamespace(data="c"), + image_content, + ] + ), + ), + ) + + runner._handle_invoke_result( + invoke_result=_stream(), + queue_manager=queue, + stream=True, + agent=False, + ) + + assert isinstance(queue.events[0], QueueLLMChunkEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.message.content == "abc" + warning_logger.assert_called_once() + + def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + exception_logger = MagicMock() + monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger) + + monkeypatch.setattr( + runner, + "_handle_multimodal_image_content", + MagicMock(side_effect=RuntimeError("failed to save image")), + ) + usage = LLMUsage.empty_usage() + + def _stream(): + yield LLMResultChunk( + model="agent-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + ImagePromptMessageContent( + url="https://example.com/image.png", + format="png", + mime_type="image/png", + ), + TextPromptMessageContent(data="done"), + ] + ), + usage=usage, + ), + ) + + runner._handle_invoke_result_stream( + invoke_result=_stream(), + queue_manager=queue, + agent=True, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + ) + + assert isinstance(queue.events[0], QueueAgentMessageEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.usage == usage + exception_logger.assert_called_once() + + def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch): + runner = AppRunner() + + class _ToggleBool: + def __init__(self, values: list[bool]): + self._values = values + self._index = 0 + + def __bool__(self): + value = self._values[min(self._index, len(self._values) - 1)] + self._index += 1 + return value + + content = SimpleNamespace( + url=_ToggleBool([False, False]), + base64_data=_ToggleBool([True, False]), + mime_type="image/png", + ) + + db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session)) + + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock()) + + runner._handle_multimodal_image_content( + content=content, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + queue_manager=queue_manager, + ) + + db_session.add.assert_not_called() + queue_manager.publish.assert_not_called() + + def test_check_hosting_moderation_direct_output_called(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(stream=False) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.HostingModerationFeature.check", + lambda self, application_generate_entity, prompt_messages: True, + ) + direct_output = MagicMock() + monkeypatch.setattr(runner, "direct_output", direct_output) + + result = runner.check_hosting_moderation( + application_generate_entity=app_generate_entity, + queue_manager=queue, + prompt_messages=[], + ) + + assert result is True + assert direct_output.called + + def test_fill_in_inputs_from_external_data_tools(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.ExternalDataFetch.fetch", + lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"}, + ) + + result = runner.fill_in_inputs_from_external_data_tools( + tenant_id="tenant", + app_id="app", + external_data_tools=[], + inputs={}, + query="q", + ) + + assert result == {"foo": "bar"} + + def test_moderation_for_inputs_returns_result(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.InputModeration.check", + lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""), + ) + app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None) + + result = runner.moderation_for_inputs( + app_id="app", + tenant_id="tenant", + app_generate_entity=app_generate_entity, + inputs={}, + query="q", + message_id="msg", + ) + + assert result == (True, {}, "") + + def test_query_app_annotations_to_reply(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.AnnotationReplyFeature.query", + lambda self, app_record, message, query, user_id, invoke_from: "reply", + ) + + response = runner.query_app_annotations_to_reply( + app_record=SimpleNamespace(), + message=SimpleNamespace(), + query="hello", + user_id="user", + invoke_from=InvokeFrom.WEB_APP, + ) + + assert response == "reply" diff --git a/api/tests/unit_tests/core/app/apps/test_exc.py b/api/tests/unit_tests/core/app/apps/test_exc.py new file mode 100644 index 0000000000..e41c78e89e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_exc.py @@ -0,0 +1,7 @@ +from core.app.apps.exc import GenerateTaskStoppedError + + +class TestAppsExceptions: + def test_generate_task_stopped_error(self): + err = GenerateTaskStoppedError("stopped") + assert str(err) == "stopped" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py index 87b8dc51e7..1250ac5ecf 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py @@ -13,9 +13,11 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.app.apps import message_based_app_generator +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from models.model import AppMode, Conversation, Message +from services.errors.app_model_config import AppModelConfigBrokenError class DummyModelConf: @@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity(): assert entity.conversation_id == "generated-conversation-id" assert entity.is_new_conversation is True assert conversation.id == "generated-conversation-id" + + +class TestMessageBasedAppGeneratorExtras: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = MessageBasedAppGenerator() + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + stream=False, + ) + + def test_get_app_model_config_requires_valid_config(self, monkeypatch): + generator = MessageBasedAppGenerator() + app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model, conversation=None) + + conversation = SimpleNamespace(app_model_config_id="missing-id") + monkeypatch.setattr( + message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None)) + ) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation) + + def test_get_conversation_introduction_handles_missing_inputs(self): + app_config = _make_app_config(AppMode.CHAT) + app_config.additional_features.opening_statement = "Hello {{name}}" + entity = _make_chat_generate_entity(app_config) + entity.inputs = {} + + generator = MessageBasedAppGenerator() + + assert generator._get_conversation_introduction(entity) == "Hello {name}" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py new file mode 100644 index 0000000000..847ad0ce9b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py @@ -0,0 +1,65 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent + + +class TestMessageBasedAppQueueManager: + def test_publish_stops_on_terminal_events(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager.stop_listen = Mock() + manager._is_stopped = Mock(return_value=False) + + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock()) + manager.stop_listen.assert_called_once() + + def test_publish_raises_when_stopped(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER) + + def test_publish_enqueues_message_end(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=False) + manager.stop_listen = Mock() + + manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + assert manager._q.qsize() == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py new file mode 100644 index 0000000000..25377e633e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode + + +class TestMessageGenerator: + def test_get_response_topic(self): + channel = Mock() + channel.topic.return_value = "topic" + + with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel): + topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1") + + assert topic == "topic" + expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1") + channel.topic.assert_called_once_with(expected_key) + + def test_retrieve_events_passes_arguments(self): + with ( + patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"), + patch( + "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) + ) as mock_stream, + ): + events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + + assert events == [{"event": "ping"}] + mock_stream.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index 7b5447c01e..a7714c56ce 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -6,6 +6,7 @@ import queue import pytest from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value with pytest.raises(StopIteration): next(generator) + + +def test_normalize_terminal_events_defaults(): + assert _normalize_terminal_events(None) == { + StreamEvent.WORKFLOW_FINISHED.value, + StreamEvent.WORKFLOW_PAUSED.value, + } + + +def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): + topic = FakeTopic() + times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] + + def fake_time(): + return times.pop(0) + + monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time) + + generator = stream_topic_events( + topic=topic, + idle_timeout=10.0, + ping_interval=1.0, + ) + + assert next(generator) == StreamEvent.PING.value + # next receive yields None -> ping interval triggers + assert next(generator) == StreamEvent.PING.value diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py new file mode 100644 index 0000000000..d8afd3b10a --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import NodeType +from dify_graph.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, +) +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable + + +class TestWorkflowBasedAppRunner: + def test_resolve_user_from(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER + + def test_init_graph_validates_graph_structure(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + with pytest.raises(ValueError, match="nodes or edges not found"): + runner._init_graph( + graph_config={}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="nodes in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": {}, "edges": []}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="edges in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": [], "edges": {}}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + def test_prepare_single_node_execution_requires_run(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + workflow = SimpleNamespace(environment_variables=[], graph_dict={}) + + with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): + runner._prepare_single_node_execution(workflow, None, None) + + def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + graph_config = { + "nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}], + "edges": [], + } + workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + lambda **kwargs: SimpleNamespace(), + ) + + class _NodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + from core.app.apps import workflow_app_runner + + monkeypatch.setitem( + workflow_app_runner.NODE_TYPE_CLASSES_MAPPING, + NodeType.START, + {"1": _NodeCls}, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="node-1", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + ) + + assert graph is not None + assert variable_pool is graph_runtime_state.variable_pool + + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append((event, publish_from)) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + graph_runtime_state.register_paused_node("node-1") + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + emails: list[dict] = [] + + class _Dispatch: + def apply_async(self, *, kwargs, queue): + emails.append({"kwargs": kwargs, "queue": queue}) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.dispatch_human_input_email_task", + _Dispatch(), + ) + + reason = HumanInputRequired( + form_id="form", + form_content="content", + node_id="node-1", + node_title="Node", + ) + + runner._handle_event(workflow_entry, GraphRunStartedEvent()) + runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True})) + runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={})) + + assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published) + assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published) + paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent)) + assert paused_event.paused_nodes == ["node-1"] + assert emails + + def test_handle_node_events_publishes_queue_events(self): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + runner._handle_event( + workflow_entry, + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + node_title="Start", + start_at=datetime.utcnow(), + ), + ) + runner._handle_event( + workflow_entry, + NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + selector=["node", "text"], + chunk="hi", + is_final=False, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunAgentLogEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + message_id="msg", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunIterationSucceededEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="Iter", + start_at=datetime.utcnow(), + inputs={}, + outputs={"ok": True}, + metadata={}, + steps=1, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunLoopFailedEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="Loop", + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + metadata={}, + steps=1, + error="boom", + ), + ) + + assert any(isinstance(event, QueueTextChunkEvent) for event in published) + assert any(isinstance(event, QueueAgentLogEvent) for event in published) + assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) + assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) diff --git a/api/tests/unit_tests/core/app/apps/workflow/__init__.py b/api/tests/unit_tests/core/app/apps/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py new file mode 100644 index 0000000000..f8dd6bf609 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from models.model import AppMode + + +class TestWorkflowAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.WORKFLOW + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + # Support both positional and keyword arguments for config + if "config" in kwargs: + config = kwargs["config"] + elif len(args) > 0: + config = args[0] + else: + config = {} + config[key] = value + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 2), + ), + patch( + "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 3), + ), + ): + filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["text_to_speech"] == 2 + assert filtered["sensitive_word_avoidance"] == 3 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py new file mode 100644 index 0000000000..6d6f9272cb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestWorkflowAppGeneratorValidation: + def test_should_prepare_user_inputs(self): + generator = WorkflowAppGenerator() + + assert generator._should_prepare_user_inputs({}) is True + assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False + + def test_single_iteration_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestWorkflowAppGeneratorHandleResponse: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_execution_id="run-id", + call_depth=0, + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + +class TestWorkflowAppGeneratorGenerate: + def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.file_factory.build_from_mappings", + lambda **kwargs: [], + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) + or setattr(self, "user_id", user_id) + }, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + prepare_inputs = pytest.fail + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs()) + + monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True}) + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(features_dict={}), + user=SimpleNamespace(id="user", session_id="session"), + args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + call_depth=0, + ) + + assert result == {"ok": True} diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py new file mode 100644 index 0000000000..6133be9867 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent + + +class TestWorkflowAppQueueManager: + def test_publish_stop_events_trigger_stop(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + manager._is_stopped = lambda: True + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER) + + def test_publish_non_stop_event_does_not_raise(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + + manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_errors.py b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py new file mode 100644 index 0000000000..7461e06833 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py @@ -0,0 +1,9 @@ +from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError + + +class TestWorkflowErrors: + def test_workflow_paused_in_blocking_mode_error_attributes(self): + err = WorkflowPausedInBlockingModeError() + assert err.error_code == "workflow_paused_in_blocking_mode" + assert err.code == 400 + assert "blocking response mode" in err.description diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py new file mode 100644 index 0000000000..62e94a7580 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -0,0 +1,133 @@ +from collections.abc import Generator + +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +class TestWorkflowGenerateResponseConverter: + def test_blocking_full_response(self): + blocking = WorkflowAppBlockingResponse( + task_id="t1", + workflow_run_id="r1", + data=WorkflowAppBlockingResponse.Data( + id="exec-1", + workflow_id="wf-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.2, + total_tokens=10, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + assert response["workflow_run_id"] == "r1" + + def test_stream_simple_response_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[WorkflowAppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1")) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish) + yield WorkflowAppStreamResponse( + workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")) + ) + + converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" + + def test_convert_stream_simple_response_handles_ping_and_nodes(self): + def _gen(): + yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task")) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeStartStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + created_at=1, + ), + ), + ) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + created_at=1, + finished_at=2, + elapsed_time=1.0, + error=None, + ), + ), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen())) + + assert chunks[0] == "ping" + assert chunks[1]["event"] == "node_started" + assert chunks[2]["event"] == "node_finished" + + def test_convert_stream_full_response_handles_error(self): + def _gen(): + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen())) + + assert chunks[0]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..b37f7a8120 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -0,0 +1,868 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + PingStreamResponse, + WorkflowFinishStreamResponse, + WorkflowPauseStreamResponse, + WorkflowStartStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import CreatorUserRole +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = SimpleNamespace(id="user", session_id="session") + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=user, + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestWorkflowGenerateTaskPipeline: + def test_to_blocking_response_handles_pause(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.status == WorkflowExecutionStatus.PAUSED + + def test_to_blocking_response_handles_finish(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowFinishStreamResponse.Data( + id="run", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.0, + total_tokens=5, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.outputs == {"ok": True} + + def test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_execution_id == "run-id" + assert responses == ["started"] + + def test_handle_node_succeeded_event_saves_output(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + pipeline._workflow_execution_id = "run-id" + + event = QueueNodeSucceededEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + + def test_handle_workflow_failed_event_yields_error(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list( + pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1)) + ) + + assert responses[0] == "finish" + + def test_handle_text_chunk_event_publishes_tts(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert responses[0].data.text == "hi" + assert published == [queue_message] + + def test_dispatch_event_handles_node_failed(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + + event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["done"] + + def test_handle_stop_event_yields_finish(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + + responses = list( + pipeline._handle_workflow_failed_and_stop_events( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + ) + ) + + assert responses == ["finish"] + + def test_save_workflow_app_log_created_from(self): + pipeline = _make_pipeline() + pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API + pipeline._user_id = "user" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + + assert added + + def test_iteration_loop_and_human_input_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter" + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done" + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=NodeType.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + agent_event = QueueAgentLogEvent( + id="log", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + node_id="node", + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"] + + def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return AudioTrunk(status="stream", audio="data") + if self.calls == 2: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + return None + + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_init_with_end_user_sets_role_and_system_user(self): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None) + end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id") + end_user.id = "end-user-id" + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=end_user, + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + + assert pipeline._created_by_role == CreatorUserRole.END_USER + assert pipeline._workflow_system_variables.user_id == "session-id" + + def test_process_returns_stream_and_blocking_variants(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.stream = True + pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + stream_response = list(pipeline.process()) + assert len(stream_response) == 1 + assert stream_response[0].workflow_run_id is None + + pipeline._base_task_pipeline.stream = False + pipeline._wrapper_process_stream_response = lambda **kwargs: iter( + [ + WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowFinishStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + created_at=1, + finished_at=2, + ), + ) + ] + ) + + blocking_response = pipeline.process() + assert blocking_response.workflow_run_id == "run-id" + + def test_to_blocking_response_handles_error_and_unexpected_end(self): + pipeline = _make_pipeline() + + def _error_gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("boom")) + + with pytest.raises(ValueError, match="boom"): + pipeline._to_blocking_response(_error_gen()) + + def _unexpected_gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(ValueError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_unexpected_gen()) + + def test_to_stream_response_tracks_workflow_run_id(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowStartStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowStartStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + inputs={}, + created_at=1, + ), + ) + yield PingStreamResponse(task_id="task") + + stream_responses = list(pipeline._to_stream_response(_gen())) + assert stream_responses[0].workflow_run_id == "run-id" + assert stream_responses[1].workflow_run_id == "run-id" + + def test_listen_audio_msg_returns_none_without_publisher(self): + pipeline = _make_pipeline() + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts(self): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = {} + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + responses = list(pipeline._wrapper_process_stream_response()) + assert responses == [PingStreamResponse(task_id="task")] + + def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + sleep_spy = [] + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + time_values = iter([0.0, 0.0, 0.2]) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values)) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True) + ) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert sleep_spy + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.called = False + + def check_and_get_audio(self): + if not self.called: + self.called = True + raise RuntimeError("tts failure") + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + logger_exception = [] + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.logger.exception", + lambda *args, **kwargs: logger_exception.append((args, kwargs)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert logger_exception + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_database_session_rolls_back_on_error(self, monkeypatch): + pipeline = _make_pipeline() + calls = {"commit": 0, "rollback": 0} + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + calls["commit"] += 1 + + def rollback(self): + calls["rollback"] += 1 + + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + with pytest.raises(RuntimeError, match="db error"): + with pipeline._database_session(): + raise RuntimeError("db error") + + assert calls["commit"] == 0 + assert calls["rollback"] == 1 + + def test_node_retry_and_started_handlers_cover_none_and_value(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + + retry_event = QueueNodeRetryEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=NodeType.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + error="error", + retry_index=1, + ) + started_event = QueueNodeStartedEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=NodeType.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + ) + + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_retry_event(retry_event)) == [] + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry" + assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"] + + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_started_event(started_event)) == [] + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started" + assert list(pipeline._handle_node_started_event(started_event)) == ["started"] + + def test_handle_node_exception_event_saves_output(self): + pipeline = _make_pipeline() + saved_ids: list[str] = [] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id) + + event = QueueNodeExceptionEvent( + node_execution_id="exec-id", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="boom", + ) + + responses = list(pipeline._handle_node_failed_events(event)) + assert responses == ["failed"] + assert saved_ids == ["exec-id"] + + def test_success_partial_and_pause_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"] + assert list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={}) + ) + ) == ["finish"] + + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [ + "pause-a", + "pause-b", + ] + pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"]) + assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"] + + def test_text_chunk_handler_returns_empty_when_text_missing(self): + pipeline = _make_pipeline() + event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None) + assert list(pipeline._handle_text_chunk_event(event)) == [] + + def test_dispatch_event_direct_failed_and_unhandled_paths(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) + assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"] + + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"]) + assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [ + "workflow-failed" + ] + + assert list(pipeline._dispatch_event(SimpleNamespace())) == [] + + def test_process_stream_response_main_match_paths_and_cleanup(self): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=QueueWorkflowStartedEvent()), + SimpleNamespace(event=QueueTextChunkEvent(text="hello")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueErrorEvent(error="e")), + ] + ) + pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"]) + pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"]) + pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"]) + pipeline._handle_error_event = lambda event, **kwargs: iter(["error"]) + publisher_calls: list[object] = [] + + class _Publisher: + def publish(self, message): + publisher_calls.append(message) + + responses = list(pipeline._process_stream_response(tts_publisher=_Publisher())) + assert responses == ["started", "text", "dispatched", "error"] + assert publisher_calls == [None] + + def test_process_stream_response_break_paths(self): + pipeline = _make_pipeline() + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"]) + assert list(pipeline._process_stream_response()) == ["failed"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))] + ) + pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"]) + assert list(pipeline._process_stream_response()) == ["paused"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"]) + assert list(pipeline._process_stream_response()) == ["stopped"] + + def test_save_workflow_app_log_covers_invoke_from_variants(self): + pipeline = _make_pipeline() + pipeline._user_id = "user-id" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "installed-app" + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "web-app" + + count_before = len(added) + pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert len(added) == count_before + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) + assert len(added) == count_before + + def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + pipeline = _make_pipeline() + saver_calls: list[tuple[object, object]] = [] + captured_factory_args: dict[str, object] = {} + + class _Saver: + def save(self, process_data, outputs): + saver_calls.append((process_data, outputs)) + + def _factory(**kwargs): + captured_factory_args.update(kwargs) + return _Saver() + + class _Begin: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb): + return False + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return _Begin() + + pipeline._draft_var_saver_factory = _factory + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + event = QueueNodeSucceededEvent( + node_execution_id="exec-id", + node_id="node-id", + node_type=NodeType.START, + in_loop_id="loop-id", + start_at=datetime.utcnow(), + process_data={"k": "v"}, + outputs={"out": 1}, + ) + pipeline._save_output_for_event(event=event, node_execution_id="exec-id") + + assert captured_factory_args["node_execution_id"] == "exec-id" + assert captured_factory_args["enclosing_node_id"] == "loop-id" + assert saver_calls == [({"k": "v"}, {"out": 1})]