test: unit test cases core.agent module (#32474)

This commit is contained in:
Rajat Agarwal 2026-03-12 08:40:15 +05:30 committed by GitHub
parent c59685748c
commit d5724aebde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 3350 additions and 61 deletions

View File

@ -0,0 +1,80 @@
import pytest
class DummyTool:
def __init__(self, name):
self.name = name
class DummyPromptEntity:
def __init__(self, first_prompt):
self.first_prompt = first_prompt
class DummyAgentConfig:
def __init__(self, prompt_entity=None):
self.prompt = prompt_entity
class DummyAppConfig:
def __init__(self, agent=None):
self.agent = agent
class DummyScratchpadUnit:
def __init__(
self,
final=False,
thought=None,
action_str=None,
observation=None,
agent_response=None,
):
self._final = final
self.thought = thought
self.action_str = action_str
self.observation = observation
self.agent_response = agent_response
def is_final(self):
return self._final
@pytest.fixture
def dummy_tool_factory():
def _factory(name):
return DummyTool(name)
return _factory
@pytest.fixture
def dummy_prompt_entity_factory():
def _factory(first_prompt):
return DummyPromptEntity(first_prompt)
return _factory
@pytest.fixture
def dummy_agent_config_factory():
def _factory(prompt_entity=None):
return DummyAgentConfig(prompt_entity)
return _factory
@pytest.fixture
def dummy_app_config_factory():
def _factory(agent=None):
return DummyAppConfig(agent)
return _factory
@pytest.fixture
def dummy_scratchpad_unit_factory():
def _factory(**kwargs):
return DummyScratchpadUnit(**kwargs)
return _factory

View File

@ -1,70 +1,255 @@
"""Unit tests for CotAgentOutputParser.
Verifies expected parsing behavior for streaming content and JSON payloads,
including edge cases such as empty/non-string content and malformed JSON.
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
model output structures. Implementation under test:
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
"""
import json
from collections.abc import Generator
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
for i in range(len(text)):
yield LLMResultChunk(
model="model",
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
@pytest.fixture
def mock_action_class(mocker):
mock_action = MagicMock()
mocker.patch(
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
mock_action,
)
return mock_action
@pytest.fixture
def usage_dict():
return {}
@pytest.fixture
def make_chunk():
def _make_chunk(content=None, usage=None):
delta = SimpleNamespace(
message=SimpleNamespace(content=content),
usage=usage,
)
return SimpleNamespace(delta=delta)
return _make_chunk
def test_cot_output_parser():
test_cases = [
{
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with json
{
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with JSON
{
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# list
{
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block
{
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block and json
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
]
# ============================================================
# Test Suite
# ============================================================
parser = CotAgentOutputParser()
usage_dict = {}
for test_case in test_cases:
# mock llm_response as a generator by text
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
results = parser.handle_react_stream_output(llm_response, usage_dict)
output = ""
for result in results:
if isinstance(result, str):
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
assert result.to_dict() == test_case["action"]
output += json.dumps(result.to_dict())
if test_case["output"]:
assert output == test_case["output"]
class TestCotAgentOutputParser:
"""Validate CotAgentOutputParser streaming + JSON parsing behavior.
Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
lightweight chunk/action doubles. Invariants: non-string/empty content
yields no output, usage gets recorded when provided, and valid action JSON
results in Action instantiation. Usage: invoke via pytest (e.g.,
`pytest -k TestCotAgentOutputParser`).
"""
# --------------------------------------------------------
# Basic streaming & usage
# --------------------------------------------------------
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("hello world")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "".join(result) == "hello world"
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk(None)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
def test_usage_update(self, make_chunk, usage_dict) -> None:
usage_data = {"tokens": 99}
chunks = [make_chunk("abc", usage=usage_data)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert usage_dict["usage"] == usage_data
# --------------------------------------------------------
# JSON parsing (direct + streaming)
# --------------------------------------------------------
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '{"action": "search", "input": "query"}'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '[{"action": "lookup", "input": "abc"}]'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
content = '{"foo": "bar"}'
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# Expect the serialized JSON to be yielded as a single element.
assert result == [json.dumps({"foo": "bar"})]
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
content = "{invalid json}"
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert any("invalid json" in str(r) for r in result)
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
chunks = [
make_chunk('{"action": '),
make_chunk('"multi", '),
make_chunk('"input": "step"}'),
]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk('{"foo": "bar"')]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
assert any('{"foo": "bar"' in item for item in result)
# --------------------------------------------------------
# Code block JSON extraction
# --------------------------------------------------------
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
content = """```json
{"action": "lookup", "input": "abc"}
```"""
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
# Multiple JSON objects inside single code fence (invalid combined JSON)
# Parser should safely ignore invalid combined block
content = """```json
{"action": "a1", "input": "x"}
{"action": "a2", "input": "y"}
```"""
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# No valid parsed action expected due to invalid combined JSON
assert mock_action_class.call_count == 0
assert isinstance(result, list)
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
content = """```json
{invalid}
```"""
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk('```json {"a":1}')]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
assert any('```json {"a":1}' in item for item in result)
# --------------------------------------------------------
# Action / Thought prefix handling
# --------------------------------------------------------
@pytest.mark.parametrize(
"content",
[
" action: something",
" ACTION: something",
" thought: reasoning",
" THOUGHT: reasoning",
],
)
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
joined = "".join(str(item) for item in result)
expected_word = "something" if "action:" in content.lower() else "reasoning"
assert expected_word in joined
assert "action:" not in joined.lower()
assert "thought:" not in joined.lower()
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("xaction: test")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "x" in "".join(map(str, result))
# --------------------------------------------------------
# Mixed streaming scenarios
# --------------------------------------------------------
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
content = 'start {"action": "mix", "input": "1"} end'
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# JSON action should be parsed
mock_action_class.assert_called_once()
# Ensure surrounding text is streamed (character-level)
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
assert "start" in joined
assert "end" in joined
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert mock_action_class.call_count == 2
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("text with ` random ` backticks")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "text with" in "".join(result)
# --------------------------------------------------------
# Boundary & edge inputs
# --------------------------------------------------------
@pytest.mark.parametrize(
"content",
[
"```",
"{",
"}",
"```json",
"action:",
"thought:",
" ",
],
)
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
joined = "".join(result)
if content == " ":
assert result == [] or joined == content
if content in {"```", "{", "}", "```json"}:
assert content in joined
if content.lower() in {"action:", "thought:"}:
assert "action:" not in joined.lower()
assert "thought:" not in joined.lower()

View File

@ -0,0 +1,174 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from core.agent.strategy.base import BaseAgentStrategy
class DummyStrategy(BaseAgentStrategy):
"""
Concrete implementation for testing BaseAgentStrategy
"""
def __init__(self, return_values=None, raise_exception=None):
self.return_values = return_values or []
self.raise_exception = raise_exception
self.received_args = None
def _invoke(
self,
params,
user_id,
conversation_id=None,
app_id=None,
message_id=None,
credentials=None,
) -> Generator:
self.received_args = (
params,
user_id,
conversation_id,
app_id,
message_id,
credentials,
)
if self.raise_exception:
raise self.raise_exception
yield from self.return_values
class TestBaseAgentStrategyInstantiation:
def test_cannot_instantiate_abstract_class(self) -> None:
with pytest.raises(TypeError):
BaseAgentStrategy()
class TestBaseAgentStrategyInvoke:
@pytest.fixture
def mock_message(self):
return MagicMock(name="AgentInvokeMessage")
@pytest.fixture
def mock_credentials(self):
return MagicMock(name="InvokeCredentials")
@pytest.mark.parametrize(
("params", "user_id", "conversation_id", "app_id", "message_id"),
[
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
({}, "user2", None, None, None),
({"a": 1}, "", "", "", ""),
({"nested": {"x": 1}}, "user3", None, "app3", None),
],
)
def test_invoke_success(
self,
mock_message,
mock_credentials,
params,
user_id,
conversation_id,
app_id,
message_id,
) -> None:
# Arrange
strategy = DummyStrategy(return_values=[mock_message])
# Act
result = list(
strategy.invoke(
params=params,
user_id=user_id,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
credentials=mock_credentials,
)
)
# Assert
assert result == [mock_message]
assert strategy.received_args == (
params,
user_id,
conversation_id,
app_id,
message_id,
mock_credentials,
)
def test_invoke_multiple_yields(self, mock_message) -> None:
# Arrange
messages = [mock_message, MagicMock(), MagicMock()]
strategy = DummyStrategy(return_values=messages)
# Act
result = list(strategy.invoke(params={}, user_id="user"))
# Assert
assert result == messages
def test_invoke_empty_generator(self) -> None:
# Arrange
strategy = DummyStrategy(return_values=[])
# Act
result = list(strategy.invoke(params={}, user_id="user"))
# Assert
assert result == []
def test_invoke_propagates_exception(self) -> None:
# Arrange
strategy = DummyStrategy(raise_exception=ValueError("failure"))
# Act & Assert
with pytest.raises(ValueError, match="failure"):
list(strategy.invoke(params={}, user_id="user"))
@pytest.mark.parametrize(
"invalid_params",
[
None,
"",
123,
[],
],
)
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
"""
Base class does not validate types ensure pass-through behavior
"""
strategy = DummyStrategy(return_values=[])
result = list(strategy.invoke(params=invalid_params, user_id="user"))
assert result == []
def test_invoke_none_user_id(self) -> None:
strategy = DummyStrategy(return_values=[])
result = list(strategy.invoke(params={}, user_id=None))
assert result == []
class TestBaseAgentStrategyGetParameters:
def test_get_parameters_default_empty_list(self) -> None:
strategy = DummyStrategy()
result = strategy.get_parameters()
assert isinstance(result, list)
assert result == []
def test_get_parameters_returns_new_list_each_time(self) -> None:
strategy = DummyStrategy()
first = strategy.get_parameters()
second = strategy.get_parameters()
assert first == second == []
assert first is not second

View File

@ -0,0 +1,272 @@
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
from unittest.mock import MagicMock
import pytest
from core.agent.strategy.plugin import PluginAgentStrategy
# ============================================================
# Fixtures
# ============================================================
@pytest.fixture
def mock_parameter():
def _factory(name="param", return_value="initialized"):
param = MagicMock()
param.name = name
param.init_frontend_parameter = MagicMock(return_value=return_value)
return param
return _factory
@pytest.fixture
def mock_declaration(mock_parameter):
param1 = mock_parameter("param1", "init1")
param2 = mock_parameter("param2", "init2")
identity = MagicMock()
identity.provider = "provider_x"
identity.name = "strategy_x"
declaration = MagicMock()
declaration.parameters = [param1, param2]
declaration.identity = identity
return declaration
@pytest.fixture
def strategy(mock_declaration):
return PluginAgentStrategy(
tenant_id="tenant_123",
declaration=mock_declaration,
meta_version="v1",
)
# ============================================================
# Initialization Tests
# ============================================================
class TestPluginAgentStrategyInitialization:
def test_init_sets_attributes(self, mock_declaration) -> None:
strategy = PluginAgentStrategy(
tenant_id="tenant_test",
declaration=mock_declaration,
meta_version="meta_v",
)
assert strategy.tenant_id == "tenant_test"
assert strategy.declaration == mock_declaration
assert strategy.meta_version == "meta_v"
def test_init_meta_version_none(self, mock_declaration) -> None:
strategy = PluginAgentStrategy(
tenant_id="tenant_test",
declaration=mock_declaration,
meta_version=None,
)
assert strategy.meta_version is None
# ============================================================
# get_parameters Tests
# ============================================================
class TestGetParameters:
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
result = strategy.get_parameters()
assert result == mock_declaration.parameters
# ============================================================
# initialize_parameters Tests
# ============================================================
class TestInitializeParameters:
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
params = {"param1": "value1"}
result = strategy.initialize_parameters(params.copy())
assert result["param1"] == "init1"
assert result["param2"] == "init2"
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
@pytest.mark.parametrize(
"input_params",
[
{},
{"param1": None},
{"param1": ""},
{"param1": 0},
{"param1": []},
{"param1": {}, "param2": "value"},
],
)
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
result = strategy.initialize_parameters(input_params.copy())
for param in strategy.declaration.parameters:
assert param.name in result
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
with pytest.raises(AttributeError):
strategy.initialize_parameters(None)
# ============================================================
# _invoke Tests
# ============================================================
class TestInvoke:
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mock_convert = mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={"converted": True},
)
result = list(
strategy._invoke(
params={"param1": "value"},
user_id="user_1",
conversation_id="conv_1",
app_id="app_1",
message_id="msg_1",
credentials=None,
)
)
assert result == ["msg1", "msg2"]
mock_convert.assert_called_once()
mock_manager.invoke.assert_called_once()
call_kwargs = mock_manager.invoke.call_args.kwargs
assert call_kwargs["tenant_id"] == "tenant_123"
assert call_kwargs["user_id"] == "user_1"
assert call_kwargs["agent_provider"] == "provider_x"
assert call_kwargs["agent_strategy"] == "strategy_x"
assert call_kwargs["agent_params"] == {"converted": True}
assert call_kwargs["conversation_id"] == "conv_1"
assert call_kwargs["app_id"] == "app_1"
assert call_kwargs["message_id"] == "msg_1"
assert call_kwargs["context"] is not None
def test_invoke_with_credentials(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
# Patch PluginInvokeContext to bypass pydantic validation
mock_context = MagicMock()
mocker.patch(
"core.agent.strategy.plugin.PluginInvokeContext",
return_value=mock_context,
)
credentials = MagicMock()
result = list(
strategy._invoke(
params={},
user_id="user_1",
credentials=credentials,
)
)
assert result == []
mock_manager.invoke.assert_called_once()
@pytest.mark.parametrize(
("conversation_id", "app_id", "message_id"),
[
(None, None, None),
("conv", None, None),
(None, "app", None),
(None, None, "msg"),
],
)
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
result = list(
strategy._invoke(
params={},
user_id="user_1",
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
)
assert result == []
mock_manager.invoke.assert_called_once()
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=MagicMock(),
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
side_effect=ValueError("conversion failed"),
)
with pytest.raises(ValueError):
list(strategy._invoke(params={}, user_id="user_1"))
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
with pytest.raises(RuntimeError):
list(strategy._invoke(params={}, user_id="user_1"))

View File

@ -0,0 +1,802 @@
import json
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
import core.agent.base_agent_runner as module
from core.agent.base_agent_runner import BaseAgentRunner
# ==========================================================
# Fixtures
# ==========================================================
@pytest.fixture
def mock_db_session(mocker):
session = mocker.MagicMock()
mocker.patch.object(module.db, "session", session)
return session
@pytest.fixture
def runner(mocker, mock_db_session):
r = BaseAgentRunner.__new__(BaseAgentRunner)
r.tenant_id = "tenant"
r.user_id = "user"
r.agent_thought_count = 0
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
r.app_config = mocker.MagicMock()
r.app_config.app_id = "app1"
r.app_config.agent = None
r.dataset_tools = []
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
r._current_thoughts = []
return r
# ==========================================================
# _repack_app_generate_entity
# ==========================================================
class TestRepack:
def test_sets_empty_if_none(self, runner, mocker):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = None
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == ""
def test_keeps_existing(self, runner, mocker):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = "abc"
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == "abc"
# ==========================================================
# update_prompt_message_tool
# ==========================================================
class TestUpdatePromptTool:
def build_param(self, mocker, **kwargs):
p = mocker.MagicMock()
p.form = kwargs.get("form")
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
p.type = mock_type
p.name = kwargs.get("name", "p1")
p.llm_description = "desc"
p.input_schema = kwargs.get("input_schema")
p.options = kwargs.get("options")
p.required = kwargs.get("required", False)
return p
def test_skip_non_llm(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(mocker, form="NOT_LLM")
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_enum_and_required(self, runner, mocker):
option = mocker.MagicMock(value="opt1")
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
options=[option],
required=True,
)
tool = mocker.MagicMock()
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert "p1" in result.parameters["required"]
def test_skip_file_type_param(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
param.type = module.ToolParameter.ToolParameterType.FILE
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_duplicate_required_not_duplicated(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
required=True,
)
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["required"].count("p1") == 1
# ==========================================================
# create_agent_thought
# ==========================================================
class TestCreateAgentThought:
def test_with_files(self, runner, mock_db_session, mocker):
mock_thought = mocker.MagicMock(id=10)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
assert result == "10"
assert runner.agent_thought_count == 1
def test_without_files(self, runner, mock_db_session, mocker):
mock_thought = mocker.MagicMock(id=11)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
assert result == "11"
# ==========================================================
# save_agent_thought
# ==========================================================
class TestSaveAgentThought:
def setup_agent(self, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1;tool2"
agent.tool_labels = {}
agent.thought = ""
return agent
def test_not_found(self, runner, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError):
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
def test_full_update(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
mock_label = mocker.MagicMock()
mock_label.to_dict.return_value = {"en_US": "label"}
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
usage = mocker.MagicMock(
prompt_tokens=1,
prompt_price_unit=Decimal("0.1"),
prompt_unit_price=Decimal("0.1"),
completion_tokens=2,
completion_price_unit=Decimal("0.2"),
completion_unit_price=Decimal("0.2"),
total_tokens=3,
total_price=Decimal("0.3"),
)
runner.save_agent_thought(
"id",
"tool1;tool2",
{"a": 1},
"thought",
{"b": 2},
{"meta": 1},
"answer",
["f1"],
usage,
)
assert agent.answer == "answer"
assert agent.tokens == 3
assert "tool1" in json.loads(agent.tool_labels_str)
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
agent.tool = "unknown_tool"
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert "unknown_tool" in labels
def test_json_failure_paths(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
bad_obj = MagicMock()
bad_obj.__str__.return_value = "bad"
runner.save_agent_thought(
"id",
None,
bad_obj,
None,
bad_obj,
bad_obj,
None,
[],
None,
)
assert mock_db_session.commit.called
def test_messages_ids_none(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
assert mock_db_session.commit.called
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought(
"id",
None,
{"a": 1},
None,
{"b": 2},
None,
None,
[],
None,
)
assert isinstance(agent.tool_input, str)
assert isinstance(agent.observation, str)
# ==========================================================
# organize_agent_user_prompt
# ==========================================================
class TestOrganizeUserPrompt:
def test_no_files(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = []
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_with_files_no_config(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
msg = mocker.MagicMock(id="1", query="hello")
msg.app_model_config.to_dict.return_value = {}
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
# ==========================================================
# organize_agent_history
# ==========================================================
class TestOrganizeHistory:
def test_empty(self, runner, mock_db_session, mocker):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
result = runner.organize_agent_history([])
assert result == []
def test_with_answer_only(self, runner, mock_db_session, mocker):
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
def test_skip_current_message(self, runner, mock_db_session, mocker):
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert result == []
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input="invalid",
observation="invalid",
thought="thinking",
)
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(tool=";", thought="thinking")
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input=json.dumps({"tool1": {"x": 1}}),
observation=json.dumps({"tool1": "obs"}),
thought="thinking",
)
msg = mocker.MagicMock(
id="m100",
agent_thoughts=[thought],
answer=None,
app_model_config=None,
)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
# ==========================================================
# _convert_tool_to_prompt_message_tool (new coverage)
# ==========================================================
class TestConvertToolToPromptMessageTool:
def test_basic_conversion(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
runtime_param = mocker.MagicMock()
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
runtime_param.name = "param1"
runtime_param.llm_description = "desc"
runtime_param.required = True
runtime_param.input_schema = None
runtime_param.options = None
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
runtime_param.type = mock_type
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
def test_full_conversion_multiple_params(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
# LLM param with input_schema override
param1 = mocker.MagicMock()
param1.form = module.ToolParameter.ToolParameterForm.LLM
param1.name = "p1"
param1.llm_description = "desc"
param1.required = True
param1.input_schema = {"type": "integer"}
param1.options = None
param1.type = mocker.MagicMock()
# SYSTEM_FILES param should be skipped
param2 = mocker.MagicMock()
param2.form = module.ToolParameter.ToolParameterForm.LLM
param2.name = "file_param"
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
# ==========================================================
# _init_prompt_tools additional branches
# ==========================================================
class TestInitPromptToolsExtended:
def test_agent_tool_branch(self, runner, mocker):
agent_tool = mocker.MagicMock(tool_name="agent_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
tools, prompts = runner._init_prompt_tools()
assert "agent_tool" in tools
def test_exception_in_conversion(self, runner, mocker):
agent_tool = mocker.MagicMock(tool_name="bad_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
tools, prompts = runner._init_prompt_tools()
assert tools == {}
# ==========================================================
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
# ==========================================================
class TestAdditionalCoverage:
def test_update_prompt_with_input_schema(self, runner, mocker):
tool = mocker.MagicMock()
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "p1"
param.required = False
param.llm_description = "desc"
param.options = None
param.input_schema = {"type": "number"}
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
param.type = mock_type
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["p1"]["type"] == "number"
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {"tool1": {"en_US": "existing"}}
agent.thought = ""
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert labels["tool1"]["en_US"] == "existing"
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
assert agent.tool_meta_str == "meta_string"
def test_convert_dataset_retriever_tool(self, runner, mocker):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
param = mocker.MagicMock()
param.name = "query"
param.llm_description = "desc"
param.required = True
ds_tool.get_runtime_parameters.return_value = [param]
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
msg = mocker.MagicMock(id="1", query="hello")
msg.app_model_config.to_dict.return_value = {}
result = runner.organize_agent_user_prompt(msg)
assert result is not None
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(tool=None, thought="thinking")
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1;tool2",
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
thought="thinking",
)
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
# ================= Additional Surgical Coverage =================
def test_convert_tool_select_enum_branch(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = True
param.llm_description = "desc"
param.input_schema = None
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool is not None
class TestConvertDatasetRetrieverTool:
def test_required_param_added(self, runner, mocker):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
param = mocker.MagicMock()
param.name = "query"
param.llm_description = "desc"
param.required = True
ds_tool.get_runtime_parameters.return_value = [param]
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
class TestBaseAgentRunnerInit:
def test_init_sets_stream_tool_call_and_files(self, mocker):
session = mocker.MagicMock()
session.query.return_value.where.return_value.count.return_value = 2
mocker.patch.object(module.db, "session", session)
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
llm = mocker.MagicMock()
llm.get_model_schema.return_value = mocker.MagicMock(
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
)
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
app_config = mocker.MagicMock()
app_config.app_id = "app1"
app_config.agent = None
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
runner = BaseAgentRunner(
tenant_id="tenant",
application_generate_entity=app_generate,
conversation=mocker.MagicMock(),
app_config=app_config,
model_config=mocker.MagicMock(),
config=mocker.MagicMock(),
queue_manager=mocker.MagicMock(),
message=message,
user_id="user",
model_instance=model_instance,
)
assert runner.stream_tool_call is True
assert runner.files == ["file1"]
assert runner.dataset_tools == ["ds_tool"]
assert runner.agent_thought_count == 2
class TestBaseAgentRunnerCoverage:
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = "NOT_LLM"
param.type = mocker.MagicMock()
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool.parameters["properties"] == {}
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
dataset_tool = mocker.MagicMock()
dataset_tool.entity.identity.name = "ds"
runner.dataset_tools = [dataset_tool]
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
tools, prompt_tools = runner._init_prompt_tools()
assert tools["ds"] == dataset_tool
assert len(prompt_tools) == 1
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
tool = mocker.MagicMock()
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = False
param.llm_description = "desc"
param.input_schema = None
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
tool_input = {"a": 1}
observation = {"b": 2}
tool_meta = {"c": 3}
real_dumps = json.dumps
def dumps_side_effect(value, *args, **kwargs):
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
raise TypeError("fail")
return real_dumps(value, *args, **kwargs)
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
runner.save_agent_thought(
"id",
"tool1",
tool_input,
None,
observation,
tool_meta,
None,
[],
None,
)
assert isinstance(agent.tool_input, str)
assert isinstance(agent.observation, str)
assert isinstance(agent.tool_meta_str, str)
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1;;"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert "" not in labels
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
system_message = module.SystemPromptMessage(content="sys")
result = runner.organize_agent_history([system_message])
assert system_message in result
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input=None,
observation=None,
thought="thinking",
)
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
mocker.patch.object(
runner,
"organize_agent_user_prompt",
return_value=module.UserPromptMessage(content="user"),
)
result = runner.organize_agent_history([])
assert any(isinstance(item, module.ToolPromptMessage) for item in result)

View File

@ -0,0 +1,551 @@
import json
from unittest.mock import MagicMock
import pytest
from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentScratchpadUnit
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.agent.exc import AgentMaxIterationError
class DummyRunner(CotAgentRunner):
"""Concrete implementation for testing abstract methods."""
def __init__(self, **kwargs):
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
for k, v in kwargs.items():
setattr(self, k, v)
# Minimal required defaults
self.history_prompt_messages = []
self.memory = None
def _organize_prompt_messages(self):
return []
@pytest.fixture
def runner(mocker):
# Prevent BaseAgentRunner __init__ from hitting database
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
return_value=[],
)
# Prepare required constructor dependencies for BaseAgentRunner
application_generate_entity = MagicMock()
application_generate_entity.model_conf = MagicMock()
application_generate_entity.model_conf.stop = []
application_generate_entity.model_conf.provider = "openai"
application_generate_entity.model_conf.parameters = {}
application_generate_entity.trace_manager = None
application_generate_entity.invoke_from = "test"
app_config = MagicMock()
app_config.agent = MagicMock()
app_config.agent.max_iteration = 1
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
model_instance.invoke_llm.return_value = []
model_config = MagicMock()
model_config.model = "test-model"
queue_manager = MagicMock()
message = MagicMock()
runner = DummyRunner(
tenant_id="tenant",
application_generate_entity=application_generate_entity,
conversation=MagicMock(),
app_config=app_config,
model_config=model_config,
config=MagicMock(),
queue_manager=queue_manager,
message=message,
user_id="user",
model_instance=model_instance,
)
# Patch internal methods to isolate behavior
runner._repack_app_generate_entity = MagicMock()
runner._init_prompt_tools = MagicMock(return_value=({}, []))
runner.recalc_llm_max_tokens = MagicMock()
runner.create_agent_thought = MagicMock(return_value="thought-id")
runner.save_agent_thought = MagicMock()
runner.update_prompt_message_tool = MagicMock()
runner.agent_callback = None
runner.memory = None
runner.history_prompt_messages = []
return runner
class TestFillInputs:
@pytest.mark.parametrize(
("instruction", "inputs", "expected"),
[
("Hello {{name}}", {"name": "John"}, "Hello John"),
("No placeholders", {"name": "John"}, "No placeholders"),
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
("{{x}}", {"x": None}, "None"),
("", {"x": "y"}, ""),
],
)
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
assert result == expected
class TestConvertDictToAction:
def test_convert_valid_dict(self, runner):
action_dict = {"action": "test", "action_input": {"a": 1}}
action = runner._convert_dict_to_action(action_dict)
assert action.action_name == "test"
assert action.action_input == {"a": 1}
def test_convert_missing_keys(self, runner):
with pytest.raises(KeyError):
runner._convert_dict_to_action({"invalid": 1})
class TestFormatAssistantMessage:
def test_format_assistant_message_multiple_scratchpads(self, runner):
sp1 = AgentScratchpadUnit(
agent_response="resp1",
thought="thought1",
action_str="action1",
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
observation="obs1",
)
sp2 = AgentScratchpadUnit(
agent_response="final",
thought="",
action_str="",
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
observation=None,
)
result = runner._format_assistant_message([sp1, sp2])
assert "Final Answer:" in result
def test_format_with_final(self, runner):
scratchpad = AgentScratchpadUnit(
agent_response="Done",
thought="",
action_str="",
action=None,
observation=None,
)
# Simulate final state via action name
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
result = runner._format_assistant_message([scratchpad])
assert "Final Answer" in result
def test_format_with_action_and_observation(self, runner):
scratchpad = AgentScratchpadUnit(
agent_response="resp",
thought="thinking",
action_str="action",
action=None,
observation="obs",
)
# Non-final state: provide a non-final action
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
result = runner._format_assistant_message([scratchpad])
assert "Thought:" in result
assert "Action:" in result
assert "Observation:" in result
class TestHandleInvokeAction:
def test_handle_invoke_action_tool_not_present(self, runner):
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
response, meta = runner._handle_invoke_action(action, {}, [])
assert "there is not a tool named" in response
def test_tool_with_json_string_args(self, runner, mocker):
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("result", [], MagicMock(to_dict=lambda: {})),
)
response, meta = runner._handle_invoke_action(action, tool_instances, [])
assert response == "result"
class TestOrganizeHistoricPromptMessages:
def test_empty_history(self, runner, mocker):
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
return_value=[],
)
result = runner._organize_historic_prompt_messages([])
assert result == []
class TestRun:
def test_run_handles_empty_parser_output(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {}))
assert isinstance(results, list)
def test_run_with_action_and_tool_invocation(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.agent_callback = None
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_respects_max_iteration_boundary(self, runner, mocker):
runner.app_config.agent.max_iteration = 1
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.agent_callback = None
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_basic_flow(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {"name": "John"}))
assert results
def test_run_max_iteration_error(self, runner, mocker):
runner.app_config.agent.max_iteration = 0
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {}))
def test_run_increase_usage_aggregation(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
runner.app_config.agent.max_iteration = 2
usage_1 = LLMUsage.empty_usage()
usage_1.prompt_tokens = 1
usage_1.completion_tokens = 1
usage_1.total_tokens = 2
usage_1.prompt_price = 1
usage_1.completion_price = 1
usage_1.total_price = 2
usage_2 = LLMUsage.empty_usage()
usage_2.prompt_tokens = 1
usage_2.completion_tokens = 1
usage_2.total_tokens = 2
usage_2.prompt_price = 1
usage_2.completion_price = 1
usage_2.total_price = 2
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
handle_output = mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
side_effect=[
[action],
[],
],
)
def _handle_side_effect(chunks, usage_dict):
call_index = handle_output.call_count
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
return [action] if call_index == 1 else []
handle_output.side_effect = _handle_side_effect
runner.model_instance.invoke_llm = MagicMock(return_value=[])
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
fake_prompt_tool = MagicMock()
fake_prompt_tool.name = "tool"
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
results = list(runner.run(message, "query", {}))
final_usage = results[-1].delta.usage
assert final_usage is not None
assert final_usage.prompt_tokens == 2
assert final_usage.completion_tokens == 2
assert final_usage.total_tokens == 4
assert final_usage.prompt_price == 2
assert final_usage.completion_price == 2
assert final_usage.total_price == 4
def test_run_when_no_action_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == ""
def test_run_usage_missing_key_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
runner.model_instance.invoke_llm = MagicMock(return_value=[])
list(runner.run(message, "query", {}))
def test_run_prompt_tool_update_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
# First iteration → action
# Second iteration → no action (empty list)
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
side_effect=[[action], []],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.app_config.agent.max_iteration = 5
fake_prompt_tool = MagicMock()
fake_prompt_tool.name = "tool"
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
runner.update_prompt_message_tool = MagicMock()
runner.agent_callback = None
list(runner.run(message, "query", {}))
runner.update_prompt_message_tool.assert_called_once()
def test_historic_with_assistant_and_tool_calls(self, runner):
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
assistant = AssistantPromptMessage(content="thinking")
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
runner.history_prompt_messages = [assistant, tool_msg]
result = runner._organize_historic_prompt_messages([])
assert isinstance(result, list)
def test_historic_final_flush_branch(self, runner):
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
assistant = AssistantPromptMessage(content="final")
runner.history_prompt_messages = [assistant]
result = runner._organize_historic_prompt_messages([])
assert isinstance(result, list)
class TestInitReactState:
def test_init_react_state_resets_state(self, runner, mocker):
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
runner._agent_scratchpad = ["old"]
runner._query = "old"
runner._init_react_state("new-query")
assert runner._query == "new-query"
assert runner._agent_scratchpad == []
assert runner._historic_prompt_messages == ["historic"]
class TestHandleInvokeActionExtended:
def test_tool_with_invalid_json_string_args(self, runner, mocker):
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
)
message_file_ids = []
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
assert response == "ok"
assert message_file_ids == ["file1"]
runner.queue_manager.publish.assert_called()
class TestFillInputsEdgeCases:
def test_fill_inputs_with_empty_inputs(self, runner):
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
assert result == "Hello {{x}}"
def test_fill_inputs_with_exception_in_replace(self, runner):
class BadValue:
def __str__(self):
raise Exception("fail")
# Should silently continue on exception
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
assert result == "Hello {{x}}"
class TestOrganizeHistoricPromptMessagesExtended:
def test_user_message_flushes_scratchpad(self, runner, mocker):
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
user_message = UserPromptMessage(content="Hi")
runner.history_prompt_messages = [user_message]
mock_transform = mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
)
mock_transform.return_value.get_prompt.return_value = ["final"]
result = runner._organize_historic_prompt_messages([])
assert result == ["final"]
def test_tool_message_without_scratchpad_raises(self, runner):
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
with pytest.raises(NotImplementedError):
runner._organize_historic_prompt_messages([])
def test_agent_history_transform_invocation(self, runner, mocker):
mock_transform = MagicMock()
mock_transform.get_prompt.return_value = []
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
return_value=mock_transform,
)
runner.history_prompt_messages = []
result = runner._organize_historic_prompt_messages([])
assert result == []
class TestRunAdditionalBranches:
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=["thinking"],
)
results = list(runner.run(message, "query", {}))
assert any(hasattr(r, "delta") for r in results)
def test_run_with_final_answer_action_string(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "done"
def test_run_with_final_answer_action_dict(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert json.loads(results[-1].delta.message.content) == {"a": 1}
def test_run_with_string_final_answer(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
# Remove invalid branch: Pydantic enforces str|dict for action_input
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "12345"

View File

@ -0,0 +1,215 @@
from unittest.mock import MagicMock, patch
import pytest
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
from tests.unit_tests.core.agent.conftest import (
DummyAgentConfig,
DummyAppConfig,
DummyTool,
)
from tests.unit_tests.core.agent.conftest import (
DummyPromptEntity as DummyPrompt,
)
class DummyFileUploadConfig:
def __init__(self, image_config=None):
self.image_config = image_config
class DummyImageConfig:
def __init__(self, detail=None):
self.detail = detail
class DummyGenerateEntity:
def __init__(self, file_upload_config=None):
self.file_upload_config = file_upload_config
class DummyUnit:
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
self._final = final
self.thought = thought
self.action_str = action_str
self.observation = observation
self.agent_response = agent_response
def is_final(self):
return self._final
@pytest.fixture
def runner():
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
runner._instruction = "test_instruction"
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
runner._query = "user query"
runner._agent_scratchpad = []
runner.files = []
runner.application_generate_entity = DummyGenerateEntity()
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
return runner
class TestOrganizeSystemPrompt:
def test_organize_system_prompt_success(self, runner, mocker):
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
mocker.patch(
"core.agent.cot_chat_agent_runner.jsonable_encoder",
return_value=[{"name": "tool1"}, {"name": "tool2"}],
)
result = runner._organize_system_prompt()
assert "test_instruction" in result.content
assert "tool1" in result.content
assert "tool2" in result.content
assert "tool1, tool2" in result.content
def test_organize_system_prompt_missing_agent(self, runner):
runner.app_config = DummyAppConfig(agent=None)
with pytest.raises(AssertionError):
runner._organize_system_prompt()
def test_organize_system_prompt_missing_prompt(self, runner):
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
with pytest.raises(AssertionError):
runner._organize_system_prompt()
class TestOrganizeUserQuery:
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
def test_organize_user_query_no_files(self, runner, files):
runner.files = files
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert result[0].content == "query"
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
mock_content = ImagePromptMessageContent(
url="http://test",
format="png",
mime_type="image/png",
)
mock_to_prompt.return_value = mock_content
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
runner.files = ["file1"]
runner.application_generate_entity = DummyGenerateEntity(None)
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
assert mock_content in result[0].content
mock_to_prompt.assert_called_once_with(
"file1",
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
)
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
mock_content = ImagePromptMessageContent(
url="http://test",
format="png",
mime_type="image/png",
)
mock_to_prompt.return_value = mock_content
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
runner.files = ["file1"]
image_config = DummyImageConfig(detail="high")
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
assert mock_content in result[0].content
mock_to_prompt.assert_called_once_with(
"file1",
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
)
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
runner.files = ["file1"]
runner.application_generate_entity = DummyGenerateEntity(None)
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
class TestOrganizePromptMessages:
def test_no_scratchpad(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
result = runner._organize_prompt_messages()
assert "system" in result
assert "query" in result
runner._organize_historic_prompt_messages.assert_called_once()
def test_with_final_scratchpad(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
unit = DummyUnit(final=True, agent_response="done")
runner._agent_scratchpad = [unit]
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Final Answer: done" in combined
def test_with_thought_action_observation(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
unit = DummyUnit(
final=False,
thought="thinking",
action_str="action",
observation="observe",
)
runner._agent_scratchpad = [unit]
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Thought: thinking" in combined
assert "Action: action" in combined
assert "Observation: observe" in combined
def test_multiple_units_mixed(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
units = [
DummyUnit(final=False, thought="t1"),
DummyUnit(final=True, agent_response="done"),
]
runner._agent_scratchpad = units
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Thought: t1" in combined
assert "Final Answer: done" in combined

View File

@ -0,0 +1,234 @@
import json
import pytest
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
# -----------------------------
# Fixtures
# -----------------------------
@pytest.fixture
def runner(mocker, dummy_tool_factory):
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
runner._instruction = "Test instruction"
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
runner._query = "What is Python?"
runner._agent_scratchpad = []
mocker.patch(
"core.agent.cot_completion_agent_runner.jsonable_encoder",
side_effect=lambda tools: [{"name": t.name} for t in tools],
)
return runner
# ======================================================
# _organize_instruction_prompt Tests
# ======================================================
class TestOrganizeInstructionPrompt:
def test_success_all_placeholders(
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
):
template = (
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
)
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
result = runner._organize_instruction_prompt()
assert "Test instruction" in result
assert "toolA" in result
assert "toolB" in result
tools_payload = json.loads(result.split(" | ")[1])
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
def test_agent_none_raises(self, runner, dummy_app_config_factory):
runner.app_config = dummy_app_config_factory(agent=None)
with pytest.raises(ValueError, match="Agent configuration is not set"):
runner._organize_instruction_prompt()
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
with pytest.raises(ValueError, match="prompt entity is not set"):
runner._organize_instruction_prompt()
# ======================================================
# _organize_historic_prompt Tests
# ======================================================
class TestOrganizeHistoricPrompt:
def test_with_user_and_assistant_string(self, runner, mocker):
user_msg = UserPromptMessage(content="Hello")
assistant_msg = AssistantPromptMessage(content="Hi there")
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[user_msg, assistant_msg],
)
result = runner._organize_historic_prompt()
assert "Question: Hello" in result
assert "Hi there" in result
def test_assistant_list_with_text_content(self, runner, mocker):
text_content = TextPromptMessageContent(data="Partial answer")
assistant_msg = AssistantPromptMessage(content=[text_content])
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[assistant_msg],
)
result = runner._organize_historic_prompt()
assert "Partial answer" in result
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
assistant_msg = AssistantPromptMessage(content=[non_text_content])
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[assistant_msg],
)
result = runner._organize_historic_prompt()
assert result == ""
def test_empty_history(self, runner, mocker):
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[],
)
result = runner._organize_historic_prompt()
assert result == ""
# ======================================================
# _organize_prompt_messages Tests
# ======================================================
class TestOrganizePromptMessages:
def test_full_flow_with_scratchpad(
self,
runner,
mocker,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
dummy_scratchpad_unit_factory,
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
runner._agent_scratchpad = [
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
]
result = runner._organize_prompt_messages()
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
content = result[0].content
assert "History" in content
assert "Thought: Thinking" in content
assert "Action: Act" in content
assert "Observation: Obs" in content
assert "Final Answer: Done" in content
assert "Question: What is Python?" in content
def test_no_scratchpad(
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
runner._agent_scratchpad = None
result = runner._organize_prompt_messages()
assert "Question: What is Python?" in result[0].content
@pytest.mark.parametrize(
("thought", "action", "observation"),
[
("T", None, None),
("T", "A", None),
("T", None, "O"),
],
)
def test_partial_scratchpad_units(
self,
runner,
mocker,
thought,
action,
observation,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
dummy_scratchpad_unit_factory,
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
runner._agent_scratchpad = [
dummy_scratchpad_unit_factory(
final=False,
thought=thought,
action_str=action,
observation=observation,
)
]
result = runner._organize_prompt_messages()
content = result[0].content
assert "Thought:" in content
if action:
assert "Action:" in content
if observation:
assert "Observation:" in content

View File

@ -0,0 +1,452 @@
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from dify_graph.nodes.agent.exc import AgentMaxIterationError
# ==============================
# Dummy Helper Classes
# ==============================
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
usage = LLMUsage.empty_usage()
usage.prompt_tokens = pt
usage.completion_tokens = ct
usage.total_tokens = tt
usage.prompt_price = 0
usage.completion_price = 0
usage.total_price = 0
return usage
class DummyMessage:
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
self.content: str | None = content
self.tool_calls: list[Any] = tool_calls or []
class DummyDelta:
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
self.message: DummyMessage | None = message
self.usage: LLMUsage | None = usage
class DummyChunk:
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
class DummyResult:
def __init__(
self,
message: DummyMessage | None = None,
usage: LLMUsage | None = None,
prompt_messages: list[DummyMessage] | None = None,
):
self.message: DummyMessage | None = message
self.usage: LLMUsage | None = usage
self.prompt_messages: list[DummyMessage] = prompt_messages or []
self.system_fingerprint: str = ""
# ==============================
# Fixtures
# ==============================
@pytest.fixture
def runner(mocker):
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
return_value=None,
)
# Patch streaming chunk models to avoid validation on dummy message objects
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
app_config = MagicMock()
app_config.agent = MagicMock(max_iteration=2)
app_config.prompt_template = MagicMock(simple_prompt_template="system")
application_generate_entity = MagicMock()
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
application_generate_entity.trace_manager = MagicMock()
application_generate_entity.invoke_from = "test"
application_generate_entity.app_config = MagicMock(app_id="app")
application_generate_entity.file_upload_config = None
queue_manager = MagicMock()
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
message = MagicMock(id="msg1")
conversation = MagicMock(id="conv1")
runner = FunctionCallAgentRunner(
tenant_id="tenant",
application_generate_entity=application_generate_entity,
conversation=conversation,
app_config=app_config,
model_config=MagicMock(),
config=MagicMock(),
queue_manager=queue_manager,
message=message,
user_id="user",
model_instance=model_instance,
)
# Manually inject required attributes normally set by BaseAgentRunner
runner.tenant_id = "tenant"
runner.application_generate_entity = application_generate_entity
runner.conversation = conversation
runner.app_config = app_config
runner.model_config = MagicMock()
runner.config = MagicMock()
runner.queue_manager = queue_manager
runner.message = message
runner.user_id = "user"
runner.model_instance = model_instance
runner.stream_tool_call = False
runner.memory = None
runner.history_prompt_messages = []
runner._current_thoughts = []
runner.files = []
runner.agent_callback = MagicMock()
runner._init_prompt_tools = MagicMock(return_value=({}, []))
runner.create_agent_thought = MagicMock(return_value="thought1")
runner.save_agent_thought = MagicMock()
runner.recalc_llm_max_tokens = MagicMock()
runner.update_prompt_message_tool = MagicMock()
return runner
# ==============================
# Tool Call Checks
# ==============================
class TestToolCallChecks:
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_tool_calls(self, runner, tool_calls, expected):
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_tool_calls(chunk) is expected
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_blocking_tool_calls(result) is expected
# ==============================
# Extract Tool Calls
# ==============================
class TestExtractToolCalls:
def test_extract_tool_calls_with_valid_json(self, runner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_tool_calls(chunk)
assert calls == [("1", "tool", {"a": 1})]
def test_extract_tool_calls_empty_arguments(self, runner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = ""
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_tool_calls(chunk)
assert calls == [("1", "tool", {})]
def test_extract_blocking_tool_calls(self, runner):
tool_call = MagicMock()
tool_call.id = "2"
tool_call.function.name = "block"
tool_call.function.arguments = json.dumps({"x": 2})
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_blocking_tool_calls(result)
assert calls == [("2", "block", {"x": 2})]
# ==============================
# System Message Initialization
# ==============================
class TestInitSystemMessage:
def test_init_system_message_empty_prompt_messages(self, runner):
result = runner._init_system_message("system", [])
assert len(result) == 1
def test_init_system_message_insert_at_start(self, runner):
msgs = [MagicMock()]
result = runner._init_system_message("system", msgs)
assert result[0].content == "system"
def test_init_system_message_no_template(self, runner):
result = runner._init_system_message("", [])
assert result == []
# ==============================
# Organize User Query
# ==============================
class TestOrganizeUserQuery:
def test_without_files(self, runner):
result = runner._organize_user_query("query", [])
assert len(result) == 1
def test_with_none_query(self, runner):
result = runner._organize_user_query(None, [])
assert len(result) == 1
def test_with_files_uses_image_detail_config(self, runner, mocker):
file_content = TextPromptMessageContent(data="file-content")
mock_to_prompt = mocker.patch(
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
return_value=file_content,
)
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
runner.files = ["file1"]
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
# ==============================
# Clear User Prompt Images
# ==============================
class TestClearUserPromptImageMessages:
def test_clear_text_and_image_content(self, runner):
text = MagicMock()
text.type = "text"
text.data = "hello"
image = MagicMock()
image.type = "image"
image.data = "img"
user_msg = MagicMock()
user_msg.__class__.__name__ = "UserPromptMessage"
user_msg.content = [text, image]
result = runner._clear_user_prompt_image_messages([user_msg])
assert isinstance(result, list)
def test_clear_includes_file_placeholder(self, runner):
text = TextPromptMessageContent(data="hello")
image = ImagePromptMessageContent(format="url", mime_type="image/png")
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
user_msg = UserPromptMessage(content=[text, image, document])
result = runner._clear_user_prompt_image_messages([user_msg])
assert result[0].content == "hello\n[image]\n[file]"
# ==============================
# Run Method Tests
# ==============================
class TestRunMethod:
def test_run_non_streaming_no_tool_calls(self, runner):
message = MagicMock(id="m1")
dummy_message = DummyMessage(content="hello")
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
runner.queue_manager.publish.assert_called()
queue_calls = runner.queue_manager.publish.call_args_list
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
def test_run_streaming_branch(self, runner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
content = [TextPromptMessageContent(data="hi")]
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
def generator():
yield chunk
runner.model_instance.invoke_llm.return_value = generator()
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_streaming_tool_calls_list_content(self, runner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
content = [TextPromptMessageContent(data="hi")]
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
def generator():
yield chunk
final_message = DummyMessage(content="done", tool_calls=[])
final_result = DummyResult(message=final_message, usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_non_streaming_list_content(self, runner):
message = MagicMock(id="m1")
content = [TextPromptMessageContent(data="hi")]
dummy_message = DummyMessage(content=content)
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
message = MagicMock(id="m1")
runner.stream_tool_call = True
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
def generator():
yield chunk
runner.model_instance.invoke_llm.return_value = generator()
real_dumps = json.dumps
def flaky_dumps(obj, *args, **kwargs):
if kwargs.get("ensure_ascii") is False:
return real_dumps(obj, *args, **kwargs)
raise TypeError("boom")
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_with_missing_tool_instance(self, runner):
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "missing"
tool_call.function.arguments = json.dumps({})
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
final_message = DummyMessage(content="done", tool_calls=[])
final_result = DummyResult(message=final_message, usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [result, final_result]
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_with_tool_instance_and_files(self, runner, mocker):
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [result, final_result]
tool_instance = MagicMock()
prompt_tool = MagicMock()
prompt_tool.name = "tool"
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
tool_invoke_meta = MagicMock()
tool_invoke_meta.to_dict.return_value = {"ok": True}
mocker.patch(
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", ["file1"], tool_invoke_meta),
)
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
assert any(
isinstance(call.args[0], QueueMessageFileEvent)
and call.args[0].message_file_id == "file1"
and call.args[1] == PublishFrom.APPLICATION_MANAGER
for call in runner.queue_manager.publish.call_args_list
)
def test_run_max_iteration_error(self, runner):
runner.app_config.agent.max_iteration = 0
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = "{}"
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query"))

View File

@ -0,0 +1,324 @@
"""Unit tests for core.agent.plugin_entities.
Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
Pydantic ValidationError behavior and pytest fixtures for validation and
mocking; ensure entity invariants and validation rules remain stable.
"""
import pytest
from pydantic import ValidationError
from core.agent.plugin_entities import (
AgentFeature,
AgentProviderEntityWithPlugin,
AgentStrategyEntity,
AgentStrategyIdentity,
AgentStrategyParameter,
AgentStrategyProviderEntity,
AgentStrategyProviderIdentity,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
# =========================================================
# Fixtures
# =========================================================
@pytest.fixture
def mock_identity(mocker):
return mocker.MagicMock(spec=AgentStrategyIdentity)
@pytest.fixture
def mock_provider_identity(mocker):
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
# =========================================================
# AgentStrategyParameterType Tests
# =========================================================
class TestAgentStrategyParameterType:
@pytest.mark.parametrize(
"enum_member",
list(AgentStrategyParameter.AgentStrategyParameterType),
)
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.as_normal_type",
return_value="normalized",
)
result = enum_member.as_normal_type()
mock_func.assert_called_once_with(enum_member)
assert result == "normalized"
def test_as_normal_type_propagates_exception(self, mocker) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.as_normal_type",
side_effect=RuntimeError("boom"),
)
with pytest.raises(RuntimeError):
enum_member.as_normal_type()
@pytest.mark.parametrize(
("enum_member", "value"),
[
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
],
)
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
return_value="casted",
)
result = enum_member.cast_value(value)
mock_func.assert_called_once_with(enum_member, value)
assert result == "casted"
def test_cast_value_propagates_exception(self, mocker) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
side_effect=ValueError("invalid"),
)
with pytest.raises(ValueError):
enum_member.cast_value("bad")
# =========================================================
# AgentStrategyParameter Tests
# =========================================================
class TestAgentStrategyParameter:
def test_valid_creation_minimal(self) -> None:
# bypass base PluginParameter required fields using model_construct
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
help=None,
)
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
assert param.help is None
def test_valid_creation_with_help(self) -> None:
help_obj = I18nObject(en_US="test")
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
help=help_obj,
)
assert param.help == help_obj
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
with pytest.raises(ValidationError) as exc_info:
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
return_value="frontend",
)
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
result = param.init_frontend_parameter("value")
mock_func.assert_called_once_with(param, param.type, "value")
assert result == "frontend"
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
side_effect=RuntimeError("error"),
)
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
with pytest.raises(RuntimeError):
param.init_frontend_parameter("value")
# =========================================================
# AgentStrategyProviderEntity Tests
# =========================================================
class TestAgentStrategyProviderEntity:
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(
identity=mock_provider_identity,
plugin_id="plugin-123",
)
assert entity.plugin_id == "plugin-123"
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(
identity=mock_provider_identity,
plugin_id="",
)
assert entity.plugin_id == ""
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
assert entity.plugin_id is None
def test_invalid_identity_raises(self) -> None:
with pytest.raises(ValidationError):
AgentStrategyProviderEntity(identity="invalid")
# =========================================================
# AgentStrategyEntity Tests
# =========================================================
class TestAgentStrategyEntity:
def test_parameters_default_empty(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
)
assert entity.parameters == []
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=None,
)
assert entity.parameters == []
def test_parameters_preserved(self, mock_identity) -> None:
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=[param],
)
assert entity.parameters == [param]
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters="invalid",
)
@pytest.mark.parametrize(
"features",
[
None,
[],
[AgentFeature.HISTORY_MESSAGES],
],
)
def test_features_valid(self, mock_identity, features) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
features=features,
)
assert entity.features == features
def test_invalid_features_type_raises(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
features="invalid",
)
def test_output_schema_and_meta_version(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
output_schema={"type": "object"},
meta_version="v1",
)
assert entity.output_schema == {"type": "object"}
assert entity.meta_version == "v1"
def test_missing_required_fields_raise(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(identity=mock_identity)
# =========================================================
# AgentProviderEntityWithPlugin Tests
# =========================================================
class TestAgentProviderEntityWithPlugin:
def test_default_strategies_empty(self, mock_provider_identity) -> None:
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
assert entity.strategies == []
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
strategy = AgentStrategyEntity.model_construct(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=[],
)
entity = AgentProviderEntityWithPlugin(
identity=mock_provider_identity,
strategies=[strategy],
)
assert entity.strategies == [strategy]
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
with pytest.raises(ValidationError):
AgentProviderEntityWithPlugin(
identity=mock_provider_identity,
strategies="invalid",
)
# =========================================================
# Inheritance Smoke Tests
# =========================================================
class TestInheritanceBehavior:
def test_agent_strategy_identity_inherits(self) -> None:
assert issubclass(AgentStrategyIdentity, ToolIdentity)
def test_agent_strategy_provider_identity_inherits(self) -> None:
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)