From b7a5ed6c0b2fd846a10b1952d6a3a272c0944347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?= Date: Wed, 25 Mar 2026 19:46:47 +0800 Subject: [PATCH] test(api): cover remaining workflow typing branches --- ...st_base_app_generate_response_converter.py | 110 ++++++++++++++++++ .../workflow/nodes/agent/test_entities.py | 46 ++++++++ .../core/workflow/nodes/llm/test_node.py | 49 ++++++++ .../workflow/nodes/tool/test_tool_node.py | 66 ++++++++++- .../workflow/test_workflow_entry_helpers.py | 16 +++ 5 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_entities.py diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py new file mode 100644 index 0000000000..58cac29042 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py @@ -0,0 +1,110 @@ +from collections.abc import Iterator + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.task_entities import AppBlockingResponse +from core.errors.error import QuotaExceededError + + +class DummyResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = AppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]: + return {"mode": "blocking-full", "task_id": blocking_response.task_id} + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]: + return {"mode": "blocking-simple", "task_id": blocking_response.task_id} + + @classmethod + def convert_stream_full_response(cls, stream_response: Iterator[object]): + for _ in stream_response: + yield {"mode": "stream-full"} + + @classmethod + def convert_stream_simple_response(cls, stream_response: Iterator[object]): + for _ in stream_response: + yield {"mode": "stream-simple"} + + +def test_convert_routes_to_full_or_simple_modes() -> None: + blocking = AppBlockingResponse(task_id="task-1") + + assert DummyResponseConverter.convert(blocking, InvokeFrom.DEBUGGER) == { + "mode": "blocking-full", + "task_id": "task-1", + } + assert DummyResponseConverter.convert(blocking, InvokeFrom.WEB_APP) == { + "mode": "blocking-simple", + "task_id": "task-1", + } + assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.SERVICE_API)) == [{"mode": "stream-full"}] + assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.WEB_APP)) == [{"mode": "stream-simple"}] + + +def test_get_simple_metadata_preserves_new_retriever_fields() -> None: + metadata = { + "retriever_resources": [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset", + "document_id": "document-1", + "segment_id": "segment-1", + "position": 1, + "data_source_type": "upload_file", + "document_name": "Document", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "hash", + "content": "content", + "page": 5, + "title": "Title", + "files": [{"id": "file-1"}], + "summary": "summary", + } + ], + "annotation_reply": "hidden", + "usage": {"latency": 0.1}, + } + + result = DummyResponseConverter._get_simple_metadata(metadata) + + assert result == { + "retriever_resources": [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset", + "document_id": "document-1", + "segment_id": "segment-1", + "position": 1, + "data_source_type": "upload_file", + "document_name": "Document", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "hash", + "content": "content", + "page": 5, + "title": "Title", + "files": [{"id": "file-1"}], + "summary": "summary", + } + ] + } + + +def test_error_to_stream_response_uses_specific_and_fallback_mappings() -> None: + quota_response = DummyResponseConverter._error_to_stream_response(QuotaExceededError()) + fallback_response = DummyResponseConverter._error_to_stream_response(RuntimeError("boom")) + + assert quota_response["code"] == "provider_quota_exceeded" + assert quota_response["status"] == 400 + assert fallback_response == { + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, + } diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_entities.py new file mode 100644 index 0000000000..dd3de6021b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_entities.py @@ -0,0 +1,46 @@ +import pytest +from pydantic import ValidationError + +from core.workflow.nodes.agent.entities import AgentNodeData + + +def test_agent_input_accepts_variable_selector_and_mixed_values() -> None: + node_data = AgentNodeData.model_validate( + { + "title": "Agent", + "agent_strategy_provider_name": "provider", + "agent_strategy_name": "strategy", + "agent_strategy_label": "Strategy", + "agent_parameters": { + "query": {"type": "variable", "value": ["start", "query"]}, + "tools": {"type": "mixed", "value": [{"provider": "builtin", "name": "search"}]}, + }, + } + ) + + assert node_data.agent_parameters["query"].value == ["start", "query"] + assert node_data.agent_parameters["tools"].value == [{"provider": "builtin", "name": "search"}] + + +def test_agent_input_rejects_invalid_variable_selector_and_unknown_type() -> None: + with pytest.raises(ValidationError): + AgentNodeData.model_validate( + { + "title": "Agent", + "agent_strategy_provider_name": "provider", + "agent_strategy_name": "strategy", + "agent_strategy_label": "Strategy", + "agent_parameters": {"query": {"type": "variable", "value": "start.query"}}, + } + ) + + with pytest.raises(ValidationError, match="Unknown agent input type"): + AgentNodeData.model_validate( + { + "title": "Agent", + "agent_strategy_provider_name": "provider", + "agent_strategy_name": "strategy", + "agent_strategy_label": "Strategy", + "agent_parameters": {"query": {"type": "unsupported", "value": "hello"}}, + } + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 22ffe64f9f..851990b6c8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -15,6 +15,7 @@ from dify_graph.entities import GraphInitParams from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities import LLMMode from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultWithStructuredOutput, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -120,6 +121,54 @@ def test_prompt_config_converts_none_jinja_variables() -> None: assert prompt_config.jinja2_variables == [] +def test_fetch_structured_output_schema_validates_required_object_shape() -> None: + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object", "a": 1}}) == { + "type": "object", + "a": 1, + } + + with pytest.raises(Exception, match="valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": None}) + + +def test_handle_blocking_result_separates_reasoning_and_structured_output() -> None: + saver = mock.MagicMock(spec=LLMFileSaver) + event = LLMNode.handle_blocking_result( + invoke_result=LLMResultWithStructuredOutput( + model="gpt", + message=AssistantPromptMessage(content="reasoninganswer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "done"}, + ), + saver=saver, + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "done"} + assert event.usage.latency == 1.234 + + +def test_handle_blocking_result_keeps_tagged_text_without_structured_output() -> None: + saver = mock.MagicMock(spec=LLMFileSaver) + event = LLMNode.handle_blocking_result( + invoke_result=LLMResult( + model="gpt", + message=AssistantPromptMessage(content="plain text"), + usage=LLMUsage.empty_usage(), + ), + saver=saver, + file_outputs=[], + ) + + assert event.text == "plain text" + assert event.reasoning_content == "" + assert event.structured_output is None + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index cc732367ff..29658a6dc7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.utils.message_transformer import ToolFileMessageTransformer from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -221,3 +221,67 @@ def test_tool_node_data_filters_missing_tool_parameter_values() -> None: ) assert set(node_data.tool_parameters.keys()) == {"query"} + + +def test_generate_parameters_reads_variables_and_optional_missing_inputs(tool_node: ToolNode) -> None: + variable_pool = MagicMock() + variable_pool.get.side_effect = [MagicMock(value="from-variable"), None] + node_data = ToolNodeData.model_validate( + { + "title": "Tool", + "provider_id": "provider", + "provider_type": "builtin", + "provider_name": "provider", + "tool_name": "tool", + "tool_label": "tool", + "tool_configurations": {}, + "tool_parameters": { + "query": {"type": "variable", "value": ["start", "query"]}, + "optional": {"type": "variable", "value": ["start", "optional"]}, + }, + } + ) + tool_parameters = [ + ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True), + ToolParameter.get_simple_instance("optional", "optional", ToolParameter.ToolParameterType.STRING, False), + ] + + result = tool_node._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=variable_pool, + node_data=node_data, + ) + + assert result == {"query": "from-variable"} + + +def test_generate_parameters_formats_logs_and_unknown_parameters(tool_node: ToolNode) -> None: + variable_pool = MagicMock() + variable_pool.convert_template.return_value = MagicMock(text="rendered", log="masked") + node_data = ToolNodeData.model_validate( + { + "title": "Tool", + "provider_id": "provider", + "provider_type": "builtin", + "provider_name": "provider", + "tool_name": "tool", + "tool_label": "tool", + "tool_configurations": {}, + "tool_parameters": { + "query": {"type": "mixed", "value": "{{ question }}"}, + "missing": {"type": "constant", "value": "literal"}, + }, + } + ) + tool_parameters = [ + ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True), + ] + + result = tool_node._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=variable_pool, + node_data=node_data, + for_log=True, + ) + + assert result == {"query": "masked", "missing": None} diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index dc4c7a00c5..47a6d3e317 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -97,6 +97,22 @@ class TestWorkflowChildEngineBuilder: ((sentinel.layer_two,), {}), ] + def test_build_child_engine_tolerates_invalid_graph_shape_until_graph_init(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + + with ( + patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory), + patch.object(workflow_entry.Graph, "init", side_effect=ValueError("invalid graph")), + ): + with pytest.raises(ValueError, match="invalid graph"): + builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": "invalid"}, + root_node_id="root", + ) + class TestWorkflowEntryInit: def test_rejects_call_depth_above_limit(self):