diff --git a/api/tests/unit_tests/core/agent/conftest.py b/api/tests/unit_tests/core/agent/conftest.py new file mode 100644 index 0000000000..a2aa501720 --- /dev/null +++ b/api/tests/unit_tests/core/agent/conftest.py @@ -0,0 +1,80 @@ +import pytest + + +class DummyTool: + def __init__(self, name): + self.name = name + + +class DummyPromptEntity: + def __init__(self, first_prompt): + self.first_prompt = first_prompt + + +class DummyAgentConfig: + def __init__(self, prompt_entity=None): + self.prompt = prompt_entity + + +class DummyAppConfig: + def __init__(self, agent=None): + self.agent = agent + + +class DummyScratchpadUnit: + def __init__( + self, + final=False, + thought=None, + action_str=None, + observation=None, + agent_response=None, + ): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def dummy_tool_factory(): + def _factory(name): + return DummyTool(name) + + return _factory + + +@pytest.fixture +def dummy_prompt_entity_factory(): + def _factory(first_prompt): + return DummyPromptEntity(first_prompt) + + return _factory + + +@pytest.fixture +def dummy_agent_config_factory(): + def _factory(prompt_entity=None): + return DummyAgentConfig(prompt_entity) + + return _factory + + +@pytest.fixture +def dummy_app_config_factory(): + def _factory(agent=None): + return DummyAppConfig(agent) + + return _factory + + +@pytest.fixture +def dummy_scratchpad_unit_factory(): + def _factory(**kwargs): + return DummyScratchpadUnit(**kwargs) + + return _factory diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index ba8c903f65..9073ae1044 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -1,70 +1,255 @@ +"""Unit tests for CotAgentOutputParser. + +Verifies expected parsing behavior for streaming content and JSON payloads, +including edge cases such as empty/non-string content and malformed JSON. +Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real +model output structures. Implementation under test: +core.agent.output_parser.cot_output_parser.CotAgentOutputParser. +""" + import json -from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest -from core.agent.entities import AgentScratchpadUnit from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta -def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]: - for i in range(len(text)): - yield LLMResultChunk( - model="model", - prompt_messages=[], - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])), +@pytest.fixture +def mock_action_class(mocker): + mock_action = MagicMock() + mocker.patch( + "core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action", + mock_action, + ) + return mock_action + + +@pytest.fixture +def usage_dict(): + return {} + + +@pytest.fixture +def make_chunk(): + def _make_chunk(content=None, usage=None): + delta = SimpleNamespace( + message=SimpleNamespace(content=content), + usage=usage, ) + return SimpleNamespace(delta=delta) + + return _make_chunk -def test_cot_output_parser(): - test_cases = [ - { - "input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with json - { - "input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with JSON - { - "input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # list - { - "input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block - { - "input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block and json - {"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"}, - ] +# ============================================================ +# Test Suite +# ============================================================ - parser = CotAgentOutputParser() - usage_dict = {} - for test_case in test_cases: - # mock llm_response as a generator by text - llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"]) - results = parser.handle_react_stream_output(llm_response, usage_dict) - output = "" - for result in results: - if isinstance(result, str): - output += result - elif isinstance(result, AgentScratchpadUnit.Action): - if test_case["action"]: - assert result.to_dict() == test_case["action"] - output += json.dumps(result.to_dict()) - if test_case["output"]: - assert output == test_case["output"] + +class TestCotAgentOutputParser: + """Validate CotAgentOutputParser streaming + JSON parsing behavior. + + Lifecycle: no explicit setup/teardown; relies on pytest fixtures for + lightweight chunk/action doubles. Invariants: non-string/empty content + yields no output, usage gets recorded when provided, and valid action JSON + results in Action instantiation. Usage: invoke via pytest (e.g., + `pytest -k TestCotAgentOutputParser`). + """ + + # -------------------------------------------------------- + # Basic streaming & usage + # -------------------------------------------------------- + + def test_stream_plain_text(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("hello world")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "".join(result) == "hello world" + + def test_stream_empty_string(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_stream_none_content(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk(None)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + @pytest.mark.parametrize("content", [123, 12.5, [], {}, object()]) + def test_non_string_content(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_usage_update(self, make_chunk, usage_dict) -> None: + usage_data = {"tokens": 99} + chunks = [make_chunk("abc", usage=usage_data)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert usage_dict["usage"] == usage_data + + # -------------------------------------------------------- + # JSON parsing (direct + streaming) + # -------------------------------------------------------- + + def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '{"action": "search", "input": "query"}' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="search", action_input="query") + + def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '[{"action": "lookup", "input": "abc"}]' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None: + content = '{"foo": "bar"}' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # Expect the serialized JSON to be yielded as a single element. + assert result == [json.dumps({"foo": "bar"})] + + def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None: + content = "{invalid json}" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert any("invalid json" in str(r) for r in result) + + def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None: + chunks = [ + make_chunk('{"action": '), + make_chunk('"multi", '), + make_chunk('"input": "step"}'), + ] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="multi", action_input="step") + + def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('{"foo": "bar"')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('{"foo": "bar"' in item for item in result) + + # -------------------------------------------------------- + # Code block JSON extraction + # -------------------------------------------------------- + + def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = """```json +{"action": "lookup", "input": "abc"} +```""" + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None: + # Multiple JSON objects inside single code fence (invalid combined JSON) + # Parser should safely ignore invalid combined block + content = """```json +{"action": "a1", "input": "x"} +{"action": "a2", "input": "y"} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # No valid parsed action expected due to invalid combined JSON + assert mock_action_class.call_count == 0 + assert isinstance(result, list) + + def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None: + content = """```json +{invalid} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result + + def test_unclosed_code_block(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('```json {"a":1}')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('```json {"a":1}' in item for item in result) + + # -------------------------------------------------------- + # Action / Thought prefix handling + # -------------------------------------------------------- + + @pytest.mark.parametrize( + "content", + [ + " action: something", + " ACTION: something", + " thought: reasoning", + " THOUGHT: reasoning", + ], + ) + def test_prefix_handling(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + joined = "".join(str(item) for item in result) + expected_word = "something" if "action:" in content.lower() else "reasoning" + assert expected_word in joined + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() + + def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("xaction: test")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "x" in "".join(map(str, result)) + + # -------------------------------------------------------- + # Mixed streaming scenarios + # -------------------------------------------------------- + + def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None: + content = 'start {"action": "mix", "input": "1"} end' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # JSON action should be parsed + mock_action_class.assert_called_once() + # Ensure surrounding text is streamed (character-level) + joined = "".join(str(r) for r in result if not isinstance(r, MagicMock)) + assert "start" in joined + assert "end" in joined + + def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert mock_action_class.call_count == 2 + + def test_backtick_noise(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("text with ` random ` backticks")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "text with" in "".join(result) + + # -------------------------------------------------------- + # Boundary & edge inputs + # -------------------------------------------------------- + + @pytest.mark.parametrize( + "content", + [ + "```", + "{", + "}", + "```json", + "action:", + "thought:", + " ", + ], + ) + def test_edge_inputs(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + joined = "".join(result) + if content == " ": + assert result == [] or joined == content + if content in {"```", "{", "}", "```json"}: + assert content in joined + if content.lower() in {"action:", "thought:"}: + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() diff --git a/api/tests/unit_tests/core/agent/strategy/test_base.py b/api/tests/unit_tests/core/agent/strategy/test_base.py new file mode 100644 index 0000000000..83ff79e8a1 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_base.py @@ -0,0 +1,174 @@ +from collections.abc import Generator +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.base import BaseAgentStrategy + + +class DummyStrategy(BaseAgentStrategy): + """ + Concrete implementation for testing BaseAgentStrategy + """ + + def __init__(self, return_values=None, raise_exception=None): + self.return_values = return_values or [] + self.raise_exception = raise_exception + self.received_args = None + + def _invoke( + self, + params, + user_id, + conversation_id=None, + app_id=None, + message_id=None, + credentials=None, + ) -> Generator: + self.received_args = ( + params, + user_id, + conversation_id, + app_id, + message_id, + credentials, + ) + + if self.raise_exception: + raise self.raise_exception + + yield from self.return_values + + +class TestBaseAgentStrategyInstantiation: + def test_cannot_instantiate_abstract_class(self) -> None: + with pytest.raises(TypeError): + BaseAgentStrategy() + + +class TestBaseAgentStrategyInvoke: + @pytest.fixture + def mock_message(self): + return MagicMock(name="AgentInvokeMessage") + + @pytest.fixture + def mock_credentials(self): + return MagicMock(name="InvokeCredentials") + + @pytest.mark.parametrize( + ("params", "user_id", "conversation_id", "app_id", "message_id"), + [ + ({"key": "value"}, "user1", "conv1", "app1", "msg1"), + ({}, "user2", None, None, None), + ({"a": 1}, "", "", "", ""), + ({"nested": {"x": 1}}, "user3", None, "app3", None), + ], + ) + def test_invoke_success( + self, + mock_message, + mock_credentials, + params, + user_id, + conversation_id, + app_id, + message_id, + ) -> None: + # Arrange + strategy = DummyStrategy(return_values=[mock_message]) + + # Act + result = list( + strategy.invoke( + params=params, + user_id=user_id, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + credentials=mock_credentials, + ) + ) + + # Assert + assert result == [mock_message] + assert strategy.received_args == ( + params, + user_id, + conversation_id, + app_id, + message_id, + mock_credentials, + ) + + def test_invoke_multiple_yields(self, mock_message) -> None: + # Arrange + messages = [mock_message, MagicMock(), MagicMock()] + strategy = DummyStrategy(return_values=messages) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == messages + + def test_invoke_empty_generator(self) -> None: + # Arrange + strategy = DummyStrategy(return_values=[]) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == [] + + def test_invoke_propagates_exception(self) -> None: + # Arrange + strategy = DummyStrategy(raise_exception=ValueError("failure")) + + # Act & Assert + with pytest.raises(ValueError, match="failure"): + list(strategy.invoke(params={}, user_id="user")) + + @pytest.mark.parametrize( + "invalid_params", + [ + None, + "", + 123, + [], + ], + ) + def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None: + """ + Base class does not validate types — ensure pass-through behavior + """ + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params=invalid_params, user_id="user")) + + assert result == [] + + def test_invoke_none_user_id(self) -> None: + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params={}, user_id=None)) + + assert result == [] + + +class TestBaseAgentStrategyGetParameters: + def test_get_parameters_default_empty_list(self) -> None: + strategy = DummyStrategy() + result = strategy.get_parameters() + + assert isinstance(result, list) + assert result == [] + + def test_get_parameters_returns_new_list_each_time(self) -> None: + strategy = DummyStrategy() + + first = strategy.get_parameters() + second = strategy.get_parameters() + + assert first == second == [] + assert first is not second diff --git a/api/tests/unit_tests/core/agent/strategy/test_plugin.py b/api/tests/unit_tests/core/agent/strategy/test_plugin.py new file mode 100644 index 0000000000..e0894f1e90 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_plugin.py @@ -0,0 +1,272 @@ +# File: tests/unit_tests/core/agent/strategy/test_plugin.py + +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.plugin import PluginAgentStrategy + +# ============================================================ +# Fixtures +# ============================================================ + + +@pytest.fixture +def mock_parameter(): + def _factory(name="param", return_value="initialized"): + param = MagicMock() + param.name = name + param.init_frontend_parameter = MagicMock(return_value=return_value) + return param + + return _factory + + +@pytest.fixture +def mock_declaration(mock_parameter): + param1 = mock_parameter("param1", "init1") + param2 = mock_parameter("param2", "init2") + + identity = MagicMock() + identity.provider = "provider_x" + identity.name = "strategy_x" + + declaration = MagicMock() + declaration.parameters = [param1, param2] + declaration.identity = identity + + return declaration + + +@pytest.fixture +def strategy(mock_declaration): + return PluginAgentStrategy( + tenant_id="tenant_123", + declaration=mock_declaration, + meta_version="v1", + ) + + +# ============================================================ +# Initialization Tests +# ============================================================ + + +class TestPluginAgentStrategyInitialization: + def test_init_sets_attributes(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version="meta_v", + ) + + assert strategy.tenant_id == "tenant_test" + assert strategy.declaration == mock_declaration + assert strategy.meta_version == "meta_v" + + def test_init_meta_version_none(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version=None, + ) + + assert strategy.meta_version is None + + +# ============================================================ +# get_parameters Tests +# ============================================================ + + +class TestGetParameters: + def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None: + result = strategy.get_parameters() + assert result == mock_declaration.parameters + + +# ============================================================ +# initialize_parameters Tests +# ============================================================ + + +class TestInitializeParameters: + def test_initialize_parameters_success(self, strategy, mock_declaration) -> None: + params = {"param1": "value1"} + + result = strategy.initialize_parameters(params.copy()) + + assert result["param1"] == "init1" + assert result["param2"] == "init2" + + mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1") + mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None) + + @pytest.mark.parametrize( + "input_params", + [ + {}, + {"param1": None}, + {"param1": ""}, + {"param1": 0}, + {"param1": []}, + {"param1": {}, "param2": "value"}, + ], + ) + def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None: + result = strategy.initialize_parameters(input_params.copy()) + + for param in strategy.declaration.parameters: + assert param.name in result + + def test_initialize_parameters_invalid_input_type(self, strategy) -> None: + with pytest.raises(AttributeError): + strategy.initialize_parameters(None) + + +# ============================================================ +# _invoke Tests +# ============================================================ + + +class TestInvoke: + def test_invoke_success_all_arguments(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mock_convert = mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={"converted": True}, + ) + + result = list( + strategy._invoke( + params={"param1": "value"}, + user_id="user_1", + conversation_id="conv_1", + app_id="app_1", + message_id="msg_1", + credentials=None, + ) + ) + + assert result == ["msg1", "msg2"] + mock_convert.assert_called_once() + mock_manager.invoke.assert_called_once() + + call_kwargs = mock_manager.invoke.call_args.kwargs + assert call_kwargs["tenant_id"] == "tenant_123" + assert call_kwargs["user_id"] == "user_1" + assert call_kwargs["agent_provider"] == "provider_x" + assert call_kwargs["agent_strategy"] == "strategy_x" + assert call_kwargs["agent_params"] == {"converted": True} + assert call_kwargs["conversation_id"] == "conv_1" + assert call_kwargs["app_id"] == "app_1" + assert call_kwargs["message_id"] == "msg_1" + assert call_kwargs["context"] is not None + + def test_invoke_with_credentials(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + # Patch PluginInvokeContext to bypass pydantic validation + mock_context = MagicMock() + mocker.patch( + "core.agent.strategy.plugin.PluginInvokeContext", + return_value=mock_context, + ) + + credentials = MagicMock() + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + credentials=credentials, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + @pytest.mark.parametrize( + ("conversation_id", "app_id", "message_id"), + [ + (None, None, None), + ("conv", None, None), + (None, "app", None), + (None, None, "msg"), + ], + ) + def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + def test_invoke_convert_raises_exception(self, strategy, mocker) -> None: + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=MagicMock(), + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + side_effect=ValueError("conversion failed"), + ) + + with pytest.raises(ValueError): + list(strategy._invoke(params={}, user_id="user_1")) + + def test_invoke_manager_raises_exception(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke.side_effect = RuntimeError("invoke failed") + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + with pytest.raises(RuntimeError): + list(strategy._invoke(params={}, user_id="user_1")) diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py new file mode 100644 index 0000000000..683cc0e36f --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -0,0 +1,802 @@ +import json +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +import core.agent.base_agent_runner as module +from core.agent.base_agent_runner import BaseAgentRunner + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture +def mock_db_session(mocker): + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + return session + + +@pytest.fixture +def runner(mocker, mock_db_session): + r = BaseAgentRunner.__new__(BaseAgentRunner) + r.tenant_id = "tenant" + r.user_id = "user" + r.agent_thought_count = 0 + r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1") + r.app_config = mocker.MagicMock() + r.app_config.app_id = "app1" + r.app_config.agent = None + r.dataset_tools = [] + r.application_generate_entity = mocker.MagicMock(invoke_from="test") + r._current_thoughts = [] + return r + + +# ========================================================== +# _repack_app_generate_entity +# ========================================================== + + +class TestRepack: + def test_sets_empty_if_none(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = None + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "" + + def test_keeps_existing(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = "abc" + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "abc" + + +# ========================================================== +# update_prompt_message_tool +# ========================================================== + + +class TestUpdatePromptTool: + def build_param(self, mocker, **kwargs): + p = mocker.MagicMock() + p.form = kwargs.get("form") + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + p.type = mock_type + + p.name = kwargs.get("name", "p1") + p.llm_description = "desc" + p.input_schema = kwargs.get("input_schema") + p.options = kwargs.get("options") + p.required = kwargs.get("required", False) + return p + + def test_skip_non_llm(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form="NOT_LLM") + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_enum_and_required(self, runner, mocker): + option = mocker.MagicMock(value="opt1") + param = self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + options=[option], + required=True, + ) + + tool = mocker.MagicMock() + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert "p1" in result.parameters["required"] + + def test_skip_file_type_param(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM) + param.type = module.ToolParameter.ToolParameterType.FILE + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_duplicate_required_not_duplicated(self, runner, mocker): + tool = mocker.MagicMock() + + param = self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + required=True, + ) + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": ["p1"]} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["required"].count("p1") == 1 + + +# ========================================================== +# create_agent_thought +# ========================================================== + + +class TestCreateAgentThought: + def test_with_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=10) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"]) + assert result == "10" + assert runner.agent_thought_count == 1 + + def test_without_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=11) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", []) + assert result == "11" + + +# ========================================================== +# save_agent_thought +# ========================================================== + + +class TestSaveAgentThought: + def setup_agent(self, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;tool2" + agent.tool_labels = {} + agent.thought = "" + return agent + + def test_not_found(self, runner, mock_db_session): + mock_db_session.scalar.return_value = None + with pytest.raises(ValueError): + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + def test_full_update(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + mock_label = mocker.MagicMock() + mock_label.to_dict.return_value = {"en_US": "label"} + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label) + + usage = mocker.MagicMock( + prompt_tokens=1, + prompt_price_unit=Decimal("0.1"), + prompt_unit_price=Decimal("0.1"), + completion_tokens=2, + completion_price_unit=Decimal("0.2"), + completion_unit_price=Decimal("0.2"), + total_tokens=3, + total_price=Decimal("0.3"), + ) + + runner.save_agent_thought( + "id", + "tool1;tool2", + {"a": 1}, + "thought", + {"b": 2}, + {"meta": 1}, + "answer", + ["f1"], + usage, + ) + + assert agent.answer == "answer" + assert agent.tokens == 3 + assert "tool1" in json.loads(agent.tool_labels_str) + + def test_label_fallback_when_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + agent.tool = "unknown_tool" + mock_db_session.scalar.return_value = agent + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert "unknown_tool" in labels + + def test_json_failure_paths(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + bad_obj = MagicMock() + bad_obj.__str__.return_value = "bad" + + runner.save_agent_thought( + "id", + None, + bad_obj, + None, + bad_obj, + bad_obj, + None, + [], + None, + ) + + assert mock_db_session.commit.called + + def test_messages_ids_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + runner.save_agent_thought("id", None, None, None, None, None, None, None, None) + assert mock_db_session.commit.called + + def test_success_dict_serialization(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought( + "id", + None, + {"a": 1}, + None, + {"b": 2}, + None, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + + +# ========================================================== +# organize_agent_user_prompt +# ========================================================== + + +class TestOrganizeUserPrompt: + def test_no_files(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_with_files_no_config(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_image_detail_low_fallback(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[]) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + +# ========================================================== +# organize_agent_history +# ========================================================== + + +class TestOrganizeHistory: + def test_empty(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_answer_only(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None) + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert any(isinstance(x, module.AssistantPromptMessage) for x in result) + + def test_skip_current_message(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None) + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input="invalid", + observation="invalid", + thought="thinking", + ) + msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_empty_tool_name_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=";", thought="thinking") + msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_valid_json_tool_flow(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=json.dumps({"tool1": {"x": 1}}), + observation=json.dumps({"tool1": "obs"}), + thought="thinking", + ) + + msg = mocker.MagicMock( + id="m100", + agent_thoughts=[thought], + answer=None, + app_model_config=None, + ) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + +# ========================================================== +# _convert_tool_to_prompt_message_tool (new coverage) +# ========================================================== + + +class TestConvertToolToPromptMessageTool: + def test_basic_conversion(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + runtime_param = mocker.MagicMock() + runtime_param.form = module.ToolParameter.ToolParameterForm.LLM + runtime_param.name = "param1" + runtime_param.llm_description = "desc" + runtime_param.required = True + runtime_param.input_schema = None + runtime_param.options = None + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + runtime_param.type = mock_type + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [runtime_param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + assert entity == tool_entity + + def test_full_conversion_multiple_params(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + # LLM param with input_schema override + param1 = mocker.MagicMock() + param1.form = module.ToolParameter.ToolParameterForm.LLM + param1.name = "p1" + param1.llm_description = "desc" + param1.required = True + param1.input_schema = {"type": "integer"} + param1.options = None + param1.type = mocker.MagicMock() + + # SYSTEM_FILES param should be skipped + param2 = mocker.MagicMock() + param2.form = module.ToolParameter.ToolParameterForm.LLM + param2.name = "file_param" + param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param1, param2] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + + assert entity == tool_entity + + +# ========================================================== +# _init_prompt_tools additional branches +# ========================================================== + + +class TestInitPromptToolsExtended: + def test_agent_tool_branch(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="agent_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity")) + + tools, prompts = runner._init_prompt_tools() + assert "agent_tool" in tools + + def test_exception_in_conversion(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="bad_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception) + + tools, prompts = runner._init_prompt_tools() + assert tools == {} + + +# ========================================================== +# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS) +# ========================================================== + + +class TestAdditionalCoverage: + def test_update_prompt_with_input_schema(self, runner, mocker): + tool = mocker.MagicMock() + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "p1" + param.required = False + param.llm_description = "desc" + param.options = None + param.input_schema = {"type": "number"} + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + param.type = mock_type + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"]["p1"]["type"] == "number" + + def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {"tool1": {"en_US": "existing"}} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert labels["tool1"]["en_US"] == "existing" + + def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None) + assert agent.tool_meta_str == "meta_string" + + def test_convert_dataset_retriever_tool(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + assert prompt is not None + + def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"]) + mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock()) + + mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw)) + mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw)) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result is not None + + def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=None, thought="thinking") + msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1;tool2", + tool_input=json.dumps({"tool1": {}, "tool2": {}}), + observation=json.dumps({"tool1": "o1", "tool2": "o2"}), + thought="thinking", + ) + msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + # ================= Additional Surgical Coverage ================= + + def test_convert_tool_select_enum_branch(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = True + param.llm_description = "desc" + param.input_schema = None + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + assert prompt_tool is not None + + +class TestConvertDatasetRetrieverTool: + def test_required_param_added(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + + assert prompt is not None + + +class TestBaseAgentRunnerInit: + def test_init_sets_stream_tool_call_and_files(self, mocker): + session = mocker.MagicMock() + session.query.return_value.where.return_value.count.return_value = 2 + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) + mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"]) + + llm = mocker.MagicMock() + llm.get_model_schema.return_value = mocker.MagicMock( + features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION] + ) + model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c") + + app_config = mocker.MagicMock() + app_config.app_id = "app1" + app_config.agent = None + app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"}) + app_config.additional_features = mocker.MagicMock(show_retrieve_source=True) + + app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"]) + message = mocker.MagicMock(id="msg1", conversation_id="conv1") + + runner = BaseAgentRunner( + tenant_id="tenant", + application_generate_entity=app_generate, + conversation=mocker.MagicMock(), + app_config=app_config, + model_config=mocker.MagicMock(), + config=mocker.MagicMock(), + queue_manager=mocker.MagicMock(), + message=message, + user_id="user", + model_instance=model_instance, + ) + + assert runner.stream_tool_call is True + assert runner.files == ["file1"] + assert runner.dataset_tools == ["ds_tool"] + assert runner.agent_thought_count == 2 + + +class TestBaseAgentRunnerCoverage: + def test_convert_tool_skips_non_llm_param(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = "NOT_LLM" + param.type = mocker.MagicMock() + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + + assert prompt_tool.parameters["properties"] == {} + + def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker): + dataset_tool = mocker.MagicMock() + dataset_tool.entity.identity.name = "ds" + runner.dataset_tools = [dataset_tool] + + mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock()) + + tools, prompt_tools = runner._init_prompt_tools() + + assert tools["ds"] == dataset_tool + assert len(prompt_tools) == 1 + + def test_update_prompt_message_tool_select_enum(self, runner, mocker): + tool = mocker.MagicMock() + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = False + param.llm_description = "desc" + param.input_schema = None + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"] + + def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + tool_input = {"a": 1} + observation = {"b": 2} + tool_meta = {"c": 3} + + real_dumps = json.dumps + + def dumps_side_effect(value, *args, **kwargs): + if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False: + raise TypeError("fail") + return real_dumps(value, *args, **kwargs) + + mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect) + + runner.save_agent_thought( + "id", + "tool1", + tool_input, + None, + observation, + tool_meta, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + assert isinstance(agent.tool_meta_str, str) + + def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;;" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + labels = json.loads(agent.tool_labels_str) + assert "" not in labels + + def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + + system_message = module.SystemPromptMessage(content="sys") + + result = runner.organize_agent_history([system_message]) + + assert system_message in result + + def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=None, + observation=None, + thought="thinking", + ) + msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + mocker.patch.object( + runner, + "organize_agent_user_prompt", + return_value=module.UserPromptMessage(content="user"), + ) + + result = runner.organize_agent_history([]) + + assert any(isinstance(item, module.ToolPromptMessage) for item in result) diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py new file mode 100644 index 0000000000..9518c61202 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -0,0 +1,551 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentScratchpadUnit +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.agent.exc import AgentMaxIterationError + + +class DummyRunner(CotAgentRunner): + """Concrete implementation for testing abstract methods.""" + + def __init__(self, **kwargs): + # Completely bypass BaseAgentRunner __init__ to avoid DB/session usage + for k, v in kwargs.items(): + setattr(self, k, v) + # Minimal required defaults + self.history_prompt_messages = [] + self.memory = None + + def _organize_prompt_messages(self): + return [] + + +@pytest.fixture +def runner(mocker): + # Prevent BaseAgentRunner __init__ from hitting database + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history", + return_value=[], + ) + # Prepare required constructor dependencies for BaseAgentRunner + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock() + application_generate_entity.model_conf.stop = [] + application_generate_entity.model_conf.provider = "openai" + application_generate_entity.model_conf.parameters = {} + application_generate_entity.trace_manager = None + application_generate_entity.invoke_from = "test" + + app_config = MagicMock() + app_config.agent = MagicMock() + app_config.agent.max_iteration = 1 + app_config.prompt_template.simple_prompt_template = "Hello {{name}}" + + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + model_instance.invoke_llm.return_value = [] + + model_config = MagicMock() + model_config.model = "test-model" + + queue_manager = MagicMock() + message = MagicMock() + + runner = DummyRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=MagicMock(), + app_config=app_config, + model_config=model_config, + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Patch internal methods to isolate behavior + runner._repack_app_generate_entity = MagicMock() + runner._init_prompt_tools = MagicMock(return_value=({}, [])) + runner.recalc_llm_max_tokens = MagicMock() + runner.create_agent_thought = MagicMock(return_value="thought-id") + runner.save_agent_thought = MagicMock() + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + runner.memory = None + runner.history_prompt_messages = [] + + return runner + + +class TestFillInputs: + @pytest.mark.parametrize( + ("instruction", "inputs", "expected"), + [ + ("Hello {{name}}", {"name": "John"}, "Hello John"), + ("No placeholders", {"name": "John"}, "No placeholders"), + ("{{a}}{{b}}", {"a": 1, "b": 2}, "12"), + ("{{x}}", {"x": None}, "None"), + ("", {"x": "y"}, ""), + ], + ) + def test_fill_in_inputs(self, runner, instruction, inputs, expected): + result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs) + assert result == expected + + +class TestConvertDictToAction: + def test_convert_valid_dict(self, runner): + action_dict = {"action": "test", "action_input": {"a": 1}} + action = runner._convert_dict_to_action(action_dict) + assert action.action_name == "test" + assert action.action_input == {"a": 1} + + def test_convert_missing_keys(self, runner): + with pytest.raises(KeyError): + runner._convert_dict_to_action({"invalid": 1}) + + +class TestFormatAssistantMessage: + def test_format_assistant_message_multiple_scratchpads(self, runner): + sp1 = AgentScratchpadUnit( + agent_response="resp1", + thought="thought1", + action_str="action1", + action=AgentScratchpadUnit.Action(action_name="tool", action_input={}), + observation="obs1", + ) + sp2 = AgentScratchpadUnit( + agent_response="final", + thought="", + action_str="", + action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"), + observation=None, + ) + result = runner._format_assistant_message([sp1, sp2]) + assert "Final Answer:" in result + + def test_format_with_final(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="Done", + thought="", + action_str="", + action=None, + observation=None, + ) + # Simulate final state via action name + scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done") + result = runner._format_assistant_message([scratchpad]) + assert "Final Answer" in result + + def test_format_with_action_and_observation(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="resp", + thought="thinking", + action_str="action", + action=None, + observation="obs", + ) + # Non-final state: provide a non-final action + scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + result = runner._format_assistant_message([scratchpad]) + assert "Thought:" in result + assert "Action:" in result + assert "Observation:" in result + + +class TestHandleInvokeAction: + def test_handle_invoke_action_tool_not_present(self, runner): + action = AgentScratchpadUnit.Action(action_name="missing", action_input={}) + response, meta = runner._handle_invoke_action(action, {}, []) + assert "there is not a tool named" in response + + def test_tool_with_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1})) + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("result", [], MagicMock(to_dict=lambda: {})), + ) + + response, meta = runner._handle_invoke_action(action, tool_instances, []) + assert response == "result" + + +class TestOrganizeHistoricPromptMessages: + def test_empty_history(self, runner, mocker): + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt", + return_value=[], + ) + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRun: + def test_run_handles_empty_parser_output(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert isinstance(results, list) + + def test_run_with_action_and_tool_invocation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_respects_max_iteration_boundary(self, runner, mocker): + runner.app_config.agent.max_iteration = 1 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_basic_flow(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {"name": "John"})) + assert results + + def test_run_max_iteration_error(self, runner, mocker): + runner.app_config.agent.max_iteration = 0 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {})) + + def test_run_increase_usage_aggregation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + runner.app_config.agent.max_iteration = 2 + + usage_1 = LLMUsage.empty_usage() + usage_1.prompt_tokens = 1 + usage_1.completion_tokens = 1 + usage_1.total_tokens = 2 + usage_1.prompt_price = 1 + usage_1.completion_price = 1 + usage_1.total_price = 2 + + usage_2 = LLMUsage.empty_usage() + usage_2.prompt_tokens = 1 + usage_2.completion_tokens = 1 + usage_2.total_tokens = 2 + usage_2.prompt_price = 1 + usage_2.completion_price = 1 + usage_2.total_price = 2 + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + handle_output = mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[ + [action], + [], + ], + ) + + def _handle_side_effect(chunks, usage_dict): + call_index = handle_output.call_count + usage_dict["usage"] = usage_1 if call_index == 1 else usage_2 + return [action] if call_index == 1 else [] + + handle_output.side_effect = _handle_side_effect + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + results = list(runner.run(message, "query", {})) + final_usage = results[-1].delta.usage + assert final_usage is not None + assert final_usage.prompt_tokens == 2 + assert final_usage.completion_tokens == 2 + assert final_usage.total_tokens == 4 + assert final_usage.prompt_price == 2 + assert final_usage.completion_price == 2 + assert final_usage.total_price == 4 + + def test_run_when_no_action_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "" + + def test_run_usage_missing_key_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + + list(runner.run(message, "query", {})) + + def test_run_prompt_tool_update_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + # First iteration → action + # Second iteration → no action (empty list) + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[[action], []], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.app_config.agent.max_iteration = 5 + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + + list(runner.run(message, "query", {})) + + runner.update_prompt_message_tool.assert_called_once() + + def test_historic_with_assistant_and_tool_calls(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + + assistant = AssistantPromptMessage(content="thinking") + assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] + + tool_msg = ToolPromptMessage(content="obs", tool_call_id="1") + + runner.history_prompt_messages = [assistant, tool_msg] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + def test_historic_final_flush_branch(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + assistant = AssistantPromptMessage(content="final") + runner.history_prompt_messages = [assistant] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + +class TestInitReactState: + def test_init_react_state_resets_state(self, runner, mocker): + mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"]) + runner._agent_scratchpad = ["old"] + runner._query = "old" + + runner._init_react_state("new-query") + + assert runner._query == "new-query" + assert runner._agent_scratchpad == [] + assert runner._historic_prompt_messages == ["historic"] + + +class TestHandleInvokeActionExtended: + def test_tool_with_invalid_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json") + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})), + ) + + message_file_ids = [] + response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids) + + assert response == "ok" + assert message_file_ids == ["file1"] + runner.queue_manager.publish.assert_called() + + +class TestFillInputsEdgeCases: + def test_fill_inputs_with_empty_inputs(self, runner): + result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {}) + assert result == "Hello {{x}}" + + def test_fill_inputs_with_exception_in_replace(self, runner): + class BadValue: + def __str__(self): + raise Exception("fail") + + # Should silently continue on exception + result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()}) + assert result == "Hello {{x}}" + + +class TestOrganizeHistoricPromptMessagesExtended: + def test_user_message_flushes_scratchpad(self, runner, mocker): + from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + user_message = UserPromptMessage(content="Hi") + + runner.history_prompt_messages = [user_message] + + mock_transform = mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + ) + mock_transform.return_value.get_prompt.return_value = ["final"] + + result = runner._organize_historic_prompt_messages([]) + assert result == ["final"] + + def test_tool_message_without_scratchpad_raises(self, runner): + from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + + runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] + + with pytest.raises(NotImplementedError): + runner._organize_historic_prompt_messages([]) + + def test_agent_history_transform_invocation(self, runner, mocker): + mock_transform = MagicMock() + mock_transform.get_prompt.return_value = [] + + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + return_value=mock_transform, + ) + + runner.history_prompt_messages = [] + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRunAdditionalBranches: + def test_run_with_no_action_final_answer_empty(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=["thinking"], + ) + + results = list(runner.run(message, "query", {})) + assert any(hasattr(r, "delta") for r in results) + + def test_run_with_final_answer_action_string(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "done" + + def test_run_with_final_answer_action_dict(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert json.loads(results[-1].delta.message.content) == {"a": 1} + + def test_run_with_string_final_answer(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + # Remove invalid branch: Pydantic enforces str|dict for action_input + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "12345" diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py new file mode 100644 index 0000000000..f9d69d1196 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -0,0 +1,215 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from tests.unit_tests.core.agent.conftest import ( + DummyAgentConfig, + DummyAppConfig, + DummyTool, +) +from tests.unit_tests.core.agent.conftest import ( + DummyPromptEntity as DummyPrompt, +) + + +class DummyFileUploadConfig: + def __init__(self, image_config=None): + self.image_config = image_config + + +class DummyImageConfig: + def __init__(self, detail=None): + self.detail = detail + + +class DummyGenerateEntity: + def __init__(self, file_upload_config=None): + self.file_upload_config = file_upload_config + + +class DummyUnit: + def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def runner(): + runner = CotChatAgentRunner.__new__(CotChatAgentRunner) + runner._instruction = "test_instruction" + runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")] + runner._query = "user query" + runner._agent_scratchpad = [] + runner.files = [] + runner.application_generate_entity = DummyGenerateEntity() + runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"]) + return runner + + +class TestOrganizeSystemPrompt: + def test_organize_system_prompt_success(self, runner, mocker): + first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}" + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt))) + + mocker.patch( + "core.agent.cot_chat_agent_runner.jsonable_encoder", + return_value=[{"name": "tool1"}, {"name": "tool2"}], + ) + + result = runner._organize_system_prompt() + + assert "test_instruction" in result.content + assert "tool1" in result.content + assert "tool2" in result.content + assert "tool1, tool2" in result.content + + def test_organize_system_prompt_missing_agent(self, runner): + runner.app_config = DummyAppConfig(agent=None) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + def test_organize_system_prompt_missing_prompt(self, runner): + runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None)) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + +class TestOrganizeUserQuery: + @pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")]) + def test_organize_user_query_no_files(self, runner, files): + runner.files = files + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert result[0].content == "query" + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.LOW, + ) + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + + image_config = DummyImageConfig(detail="high") + runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config)) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner): + mock_to_prompt.return_value = TextPromptMessageContent(data="file_content") + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + + +class TestOrganizePromptMessages: + def test_no_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + result = runner._organize_prompt_messages() + assert "system" in result + assert "query" in result + runner._organize_historic_prompt_messages.assert_called_once() + + def test_with_final_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit(final=True, agent_response="done") + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Final Answer: done" in combined + + def test_with_thought_action_observation(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit( + final=False, + thought="thinking", + action_str="action", + observation="observe", + ) + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: thinking" in combined + assert "Action: action" in combined + assert "Observation: observe" in combined + + def test_multiple_units_mixed(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + units = [ + DummyUnit(final=False, thought="t1"), + DummyUnit(final=True, agent_response="done"), + ] + runner._agent_scratchpad = units + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: t1" in combined + assert "Final Answer: done" in combined diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py new file mode 100644 index 0000000000..ab822bb57d --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -0,0 +1,234 @@ +import json + +import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def runner(mocker, dummy_tool_factory): + runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner) + + runner._instruction = "Test instruction" + runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")] + runner._query = "What is Python?" + runner._agent_scratchpad = [] + + mocker.patch( + "core.agent.cot_completion_agent_runner.jsonable_encoder", + side_effect=lambda tools: [{"name": t.name} for t in tools], + ) + + return runner + + +# ====================================================== +# _organize_instruction_prompt Tests +# ====================================================== + + +class TestOrganizeInstructionPrompt: + def test_success_all_placeholders( + self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = ( + "{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}" + ) + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + result = runner._organize_instruction_prompt() + + assert "Test instruction" in result + assert "toolA" in result + assert "toolB" in result + tools_payload = json.loads(result.split(" | ")[1]) + assert {item["name"] for item in tools_payload} == {"toolA", "toolB"} + + def test_agent_none_raises(self, runner, dummy_app_config_factory): + runner.app_config = dummy_app_config_factory(agent=None) + with pytest.raises(ValueError, match="Agent configuration is not set"): + runner._organize_instruction_prompt() + + def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory): + runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None)) + with pytest.raises(ValueError, match="prompt entity is not set"): + runner._organize_instruction_prompt() + + +# ====================================================== +# _organize_historic_prompt Tests +# ====================================================== + + +class TestOrganizeHistoricPrompt: + def test_with_user_and_assistant_string(self, runner, mocker): + user_msg = UserPromptMessage(content="Hello") + assistant_msg = AssistantPromptMessage(content="Hi there") + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[user_msg, assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Question: Hello" in result + assert "Hi there" in result + + def test_assistant_list_with_text_content(self, runner, mocker): + text_content = TextPromptMessageContent(data="Partial answer") + assistant_msg = AssistantPromptMessage(content=[text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Partial answer" in result + + def test_assistant_list_with_non_text_content_ignored(self, runner, mocker): + non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png") + assistant_msg = AssistantPromptMessage(content=[non_text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + def test_empty_history(self, runner, mocker): + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + +# ====================================================== +# _organize_prompt_messages Tests +# ====================================================== + + +class TestOrganizePromptMessages: + def test_full_flow_with_scratchpad( + self, + runner, + mocker, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"), + dummy_scratchpad_unit_factory(final=True, agent_response="Done"), + ] + + result = runner._organize_prompt_messages() + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + + content = result[0].content + + assert "History" in content + assert "Thought: Thinking" in content + assert "Action: Act" in content + assert "Observation: Obs" in content + assert "Final Answer: Done" in content + assert "Question: What is Python?" in content + + def test_no_scratchpad( + self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = None + + result = runner._organize_prompt_messages() + + assert "Question: What is Python?" in result[0].content + + @pytest.mark.parametrize( + ("thought", "action", "observation"), + [ + ("T", None, None), + ("T", "A", None), + ("T", None, "O"), + ], + ) + def test_partial_scratchpad_units( + self, + runner, + mocker, + thought, + action, + observation, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory( + final=False, + thought=thought, + action_str=action, + observation=observation, + ) + ] + + result = runner._organize_prompt_messages() + content = result[0].content + + assert "Thought:" in content + if action: + assert "Action:" in content + if observation: + assert "Observation:" in content diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py new file mode 100644 index 0000000000..8843a8d505 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -0,0 +1,452 @@ +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + DocumentPromptMessageContent, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.nodes.agent.exc import AgentMaxIterationError + +# ============================== +# Dummy Helper Classes +# ============================== + + +def build_usage(pt=1, ct=1, tt=2) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.prompt_tokens = pt + usage.completion_tokens = ct + usage.total_tokens = tt + usage.prompt_price = 0 + usage.completion_price = 0 + usage.total_price = 0 + return usage + + +class DummyMessage: + def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None): + self.content: str | None = content + self.tool_calls: list[Any] = tool_calls or [] + + +class DummyDelta: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + + +class DummyChunk: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.delta: DummyDelta = DummyDelta(message=message, usage=usage) + + +class DummyResult: + def __init__( + self, + message: DummyMessage | None = None, + usage: LLMUsage | None = None, + prompt_messages: list[DummyMessage] | None = None, + ): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + self.prompt_messages: list[DummyMessage] = prompt_messages or [] + self.system_fingerprint: str = "" + + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def runner(mocker): + # Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.__init__", + return_value=None, + ) + + # Patch streaming chunk models to avoid validation on dummy message objects + mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock) + mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock) + + app_config = MagicMock() + app_config.agent = MagicMock(max_iteration=2) + app_config.prompt_template = MagicMock(simple_prompt_template="system") + + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock(parameters={}, stop=None) + application_generate_entity.trace_manager = MagicMock() + application_generate_entity.invoke_from = "test" + application_generate_entity.app_config = MagicMock(app_id="app") + application_generate_entity.file_upload_config = None + + queue_manager = MagicMock() + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + + message = MagicMock(id="msg1") + conversation = MagicMock(id="conv1") + + runner = FunctionCallAgentRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=conversation, + app_config=app_config, + model_config=MagicMock(), + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Manually inject required attributes normally set by BaseAgentRunner + runner.tenant_id = "tenant" + runner.application_generate_entity = application_generate_entity + runner.conversation = conversation + runner.app_config = app_config + runner.model_config = MagicMock() + runner.config = MagicMock() + runner.queue_manager = queue_manager + runner.message = message + runner.user_id = "user" + runner.model_instance = model_instance + + runner.stream_tool_call = False + runner.memory = None + runner.history_prompt_messages = [] + runner._current_thoughts = [] + runner.files = [] + runner.agent_callback = MagicMock() + + runner._init_prompt_tools = MagicMock(return_value=({}, [])) + runner.create_agent_thought = MagicMock(return_value="thought1") + runner.save_agent_thought = MagicMock() + runner.recalc_llm_max_tokens = MagicMock() + runner.update_prompt_message_tool = MagicMock() + + return runner + + +# ============================== +# Tool Call Checks +# ============================== + + +class TestToolCallChecks: + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_tool_calls(self, runner, tool_calls, expected): + chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_tool_calls(chunk) is expected + + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_blocking_tool_calls(self, runner, tool_calls, expected): + result = DummyResult(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_blocking_tool_calls(result) is expected + + +# ============================== +# Extract Tool Calls +# ============================== + + +class TestExtractToolCalls: + def test_extract_tool_calls_with_valid_json(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {"a": 1})] + + def test_extract_tool_calls_empty_arguments(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "" + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {})] + + def test_extract_blocking_tool_calls(self, runner): + tool_call = MagicMock() + tool_call.id = "2" + tool_call.function.name = "block" + tool_call.function.arguments = json.dumps({"x": 2}) + + result = DummyResult(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_blocking_tool_calls(result) + + assert calls == [("2", "block", {"x": 2})] + + +# ============================== +# System Message Initialization +# ============================== + + +class TestInitSystemMessage: + def test_init_system_message_empty_prompt_messages(self, runner): + result = runner._init_system_message("system", []) + assert len(result) == 1 + + def test_init_system_message_insert_at_start(self, runner): + msgs = [MagicMock()] + result = runner._init_system_message("system", msgs) + assert result[0].content == "system" + + def test_init_system_message_no_template(self, runner): + result = runner._init_system_message("", []) + assert result == [] + + +# ============================== +# Organize User Query +# ============================== + + +class TestOrganizeUserQuery: + def test_without_files(self, runner): + result = runner._organize_user_query("query", []) + assert len(result) == 1 + + def test_with_none_query(self, runner): + result = runner._organize_user_query(None, []) + assert len(result) == 1 + + def test_with_files_uses_image_detail_config(self, runner, mocker): + file_content = TextPromptMessageContent(data="file-content") + mock_to_prompt = mocker.patch( + "core.agent.fc_agent_runner.file_manager.to_prompt_message_content", + return_value=file_content, + ) + + image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH) + runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config) + runner.files = ["file1"] + + result = runner._organize_user_query("query", []) + + assert len(result) == 1 + assert isinstance(result[0].content, list) + mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) + + +# ============================== +# Clear User Prompt Images +# ============================== + + +class TestClearUserPromptImageMessages: + def test_clear_text_and_image_content(self, runner): + text = MagicMock() + text.type = "text" + text.data = "hello" + + image = MagicMock() + image.type = "image" + image.data = "img" + + user_msg = MagicMock() + user_msg.__class__.__name__ = "UserPromptMessage" + user_msg.content = [text, image] + + result = runner._clear_user_prompt_image_messages([user_msg]) + assert isinstance(result, list) + + def test_clear_includes_file_placeholder(self, runner): + text = TextPromptMessageContent(data="hello") + image = ImagePromptMessageContent(format="url", mime_type="image/png") + document = DocumentPromptMessageContent(format="url", mime_type="application/pdf") + + user_msg = UserPromptMessage(content=[text, image, document]) + + result = runner._clear_user_prompt_image_messages([user_msg]) + + assert result[0].content == "hello\n[image]\n[file]" + + +# ============================== +# Run Method Tests +# ============================== + + +class TestRunMethod: + def test_run_non_streaming_no_tool_calls(self, runner): + message = MagicMock(id="m1") + dummy_message = DummyMessage(content="hello") + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + runner.queue_manager.publish.assert_called() + + queue_calls = runner.queue_manager.publish.call_args_list + assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls) + + def test_run_streaming_branch(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_streaming_tool_calls_list_content(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [generator(), final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_non_streaming_list_content(self, runner): + message = MagicMock(id="m1") + content = [TextPromptMessageContent(data="hi")] + dummy_message = DummyMessage(content=content) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi" + + def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + real_dumps = json.dumps + + def flaky_dumps(obj, *args, **kwargs): + if kwargs.get("ensure_ascii") is False: + return real_dumps(obj, *args, **kwargs) + raise TypeError("boom") + + mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps) + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_with_missing_tool_instance(self, runner): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "missing" + tool_call.function.arguments = json.dumps({}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_with_tool_instance_and_files(self, runner, mocker): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + tool_instance = MagicMock() + prompt_tool = MagicMock() + prompt_tool.name = "tool" + runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool]) + + tool_invoke_meta = MagicMock() + tool_invoke_meta.to_dict.return_value = {"ok": True} + mocker.patch( + "core.agent.fc_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], tool_invoke_meta), + ) + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + assert any( + isinstance(call.args[0], QueueMessageFileEvent) + and call.args[0].message_file_id == "file1" + and call.args[1] == PublishFrom.APPLICATION_MANAGER + for call in runner.queue_manager.publish.call_args_list + ) + + def test_run_max_iteration_error(self, runner): + runner.app_config.agent.max_iteration = 0 + + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "{}" + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query")) diff --git a/api/tests/unit_tests/core/agent/test_plugin_entities.py b/api/tests/unit_tests/core/agent/test_plugin_entities.py new file mode 100644 index 0000000000..9955190aca --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_plugin_entities.py @@ -0,0 +1,324 @@ +"""Unit tests for core.agent.plugin_entities. + +Covers entities such as AgentFeature, AgentProviderEntityWithPlugin, +AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter, +AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on +Pydantic ValidationError behavior and pytest fixtures for validation and +mocking; ensure entity invariants and validation rules remain stable. +""" + +import pytest +from pydantic import ValidationError + +from core.agent.plugin_entities import ( + AgentFeature, + AgentProviderEntityWithPlugin, + AgentStrategyEntity, + AgentStrategyIdentity, + AgentStrategyParameter, + AgentStrategyProviderEntity, + AgentStrategyProviderIdentity, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity + +# ========================================================= +# Fixtures +# ========================================================= + + +@pytest.fixture +def mock_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyIdentity) + + +@pytest.fixture +def mock_provider_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyProviderIdentity) + + +# ========================================================= +# AgentStrategyParameterType Tests +# ========================================================= + + +class TestAgentStrategyParameterType: + @pytest.mark.parametrize( + "enum_member", + list(AgentStrategyParameter.AgentStrategyParameterType), + ) + def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.as_normal_type", + return_value="normalized", + ) + + result = enum_member.as_normal_type() + + mock_func.assert_called_once_with(enum_member) + assert result == "normalized" + + def test_as_normal_type_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.as_normal_type", + side_effect=RuntimeError("boom"), + ) + + with pytest.raises(RuntimeError): + enum_member.as_normal_type() + + @pytest.mark.parametrize( + ("enum_member", "value"), + [ + (AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"), + (AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10), + (AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True), + (AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}), + (AgentStrategyParameter.AgentStrategyParameterType.STRING, None), + (AgentStrategyParameter.AgentStrategyParameterType.FILES, []), + ], + ) + def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + return_value="casted", + ) + + result = enum_member.cast_value(value) + + mock_func.assert_called_once_with(enum_member, value) + assert result == "casted" + + def test_cast_value_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + side_effect=ValueError("invalid"), + ) + + with pytest.raises(ValueError): + enum_member.cast_value("bad") + + +# ========================================================= +# AgentStrategyParameter Tests +# ========================================================= + + +class TestAgentStrategyParameter: + def test_valid_creation_minimal(self) -> None: + # bypass base PluginParameter required fields using model_construct + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=None, + ) + assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING + assert param.help is None + + def test_valid_creation_with_help(self) -> None: + help_obj = I18nObject(en_US="test") + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=help_obj, + ) + assert param.help == help_obj + + @pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}]) + def test_invalid_type_raises_validation_error(self, invalid_type) -> None: + with pytest.raises(ValidationError) as exc_info: + AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y")) + + assert any(error["loc"] == ("type",) for error in exc_info.value.errors()) + + def test_init_frontend_parameter_calls_external(self, mocker) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + return_value="frontend", + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + result = param.init_frontend_parameter("value") + + mock_func.assert_called_once_with(param, param.type, "value") + assert result == "frontend" + + def test_init_frontend_parameter_propagates_exception(self, mocker) -> None: + mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + side_effect=RuntimeError("error"), + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + with pytest.raises(RuntimeError): + param.init_frontend_parameter("value") + + +# ========================================================= +# AgentStrategyProviderEntity Tests +# ========================================================= + + +class TestAgentStrategyProviderEntity: + def test_creation_with_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="plugin-123", + ) + assert entity.plugin_id == "plugin-123" + + def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="", + ) + assert entity.plugin_id == "" + + def test_creation_without_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity(identity=mock_provider_identity) + assert entity.plugin_id is None + + def test_invalid_identity_raises(self) -> None: + with pytest.raises(ValidationError): + AgentStrategyProviderEntity(identity="invalid") + + +# ========================================================= +# AgentStrategyEntity Tests +# ========================================================= + + +class TestAgentStrategyEntity: + def test_parameters_default_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + ) + assert entity.parameters == [] + + def test_parameters_none_converted_to_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=None, + ) + assert entity.parameters == [] + + def test_parameters_preserved(self, mock_identity) -> None: + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[param], + ) + assert entity.parameters == [param] + + def test_invalid_parameters_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters="invalid", + ) + + @pytest.mark.parametrize( + "features", + [ + None, + [], + [AgentFeature.HISTORY_MESSAGES], + ], + ) + def test_features_valid(self, mock_identity, features) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features=features, + ) + assert entity.features == features + + def test_invalid_features_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features="invalid", + ) + + def test_output_schema_and_meta_version(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + output_schema={"type": "object"}, + meta_version="v1", + ) + assert entity.output_schema == {"type": "object"} + assert entity.meta_version == "v1" + + def test_missing_required_fields_raise(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity(identity=mock_identity) + + +# ========================================================= +# AgentProviderEntityWithPlugin Tests +# ========================================================= + + +class TestAgentProviderEntityWithPlugin: + def test_default_strategies_empty(self, mock_provider_identity) -> None: + entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity) + assert entity.strategies == [] + + def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None: + strategy = AgentStrategyEntity.model_construct( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[], + ) + + entity = AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies=[strategy], + ) + assert entity.strategies == [strategy] + + def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None: + with pytest.raises(ValidationError): + AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies="invalid", + ) + + +# ========================================================= +# Inheritance Smoke Tests +# ========================================================= + + +class TestInheritanceBehavior: + def test_agent_strategy_identity_inherits(self) -> None: + assert issubclass(AgentStrategyIdentity, ToolIdentity) + + def test_agent_strategy_provider_identity_inherits(self) -> None: + assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)