mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: unit test cases core.agent module (#32474)
This commit is contained in:
parent
c59685748c
commit
d5724aebde
80
api/tests/unit_tests/core/agent/conftest.py
Normal file
80
api/tests/unit_tests/core/agent/conftest.py
Normal file
@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
|
||||
|
||||
class DummyTool:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class DummyPromptEntity:
|
||||
def __init__(self, first_prompt):
|
||||
self.first_prompt = first_prompt
|
||||
|
||||
|
||||
class DummyAgentConfig:
|
||||
def __init__(self, prompt_entity=None):
|
||||
self.prompt = prompt_entity
|
||||
|
||||
|
||||
class DummyAppConfig:
|
||||
def __init__(self, agent=None):
|
||||
self.agent = agent
|
||||
|
||||
|
||||
class DummyScratchpadUnit:
|
||||
def __init__(
|
||||
self,
|
||||
final=False,
|
||||
thought=None,
|
||||
action_str=None,
|
||||
observation=None,
|
||||
agent_response=None,
|
||||
):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_tool_factory():
|
||||
def _factory(name):
|
||||
return DummyTool(name)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_prompt_entity_factory():
|
||||
def _factory(first_prompt):
|
||||
return DummyPromptEntity(first_prompt)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_agent_config_factory():
|
||||
def _factory(prompt_entity=None):
|
||||
return DummyAgentConfig(prompt_entity)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_app_config_factory():
|
||||
def _factory(agent=None):
|
||||
return DummyAppConfig(agent)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_scratchpad_unit_factory():
|
||||
def _factory(**kwargs):
|
||||
return DummyScratchpadUnit(**kwargs)
|
||||
|
||||
return _factory
|
||||
@ -1,70 +1,255 @@
|
||||
"""Unit tests for CotAgentOutputParser.
|
||||
|
||||
Verifies expected parsing behavior for streaming content and JSON payloads,
|
||||
including edge cases such as empty/non-string content and malformed JSON.
|
||||
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
|
||||
model output structures. Implementation under test:
|
||||
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
|
||||
for i in range(len(text)):
|
||||
yield LLMResultChunk(
|
||||
model="model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
|
||||
@pytest.fixture
|
||||
def mock_action_class(mocker):
|
||||
mock_action = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
|
||||
mock_action,
|
||||
)
|
||||
return mock_action
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_dict():
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_chunk():
|
||||
def _make_chunk(content=None, usage=None):
|
||||
delta = SimpleNamespace(
|
||||
message=SimpleNamespace(content=content),
|
||||
usage=usage,
|
||||
)
|
||||
return SimpleNamespace(delta=delta)
|
||||
|
||||
return _make_chunk
|
||||
|
||||
|
||||
def test_cot_output_parser():
|
||||
test_cases = [
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with json
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with JSON
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# list
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block
|
||||
{
|
||||
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block and json
|
||||
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
|
||||
]
|
||||
# ============================================================
|
||||
# Test Suite
|
||||
# ============================================================
|
||||
|
||||
parser = CotAgentOutputParser()
|
||||
usage_dict = {}
|
||||
for test_case in test_cases:
|
||||
# mock llm_response as a generator by text
|
||||
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
|
||||
results = parser.handle_react_stream_output(llm_response, usage_dict)
|
||||
output = ""
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
||||
|
||||
class TestCotAgentOutputParser:
|
||||
"""Validate CotAgentOutputParser streaming + JSON parsing behavior.
|
||||
|
||||
Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
|
||||
lightweight chunk/action doubles. Invariants: non-string/empty content
|
||||
yields no output, usage gets recorded when provided, and valid action JSON
|
||||
results in Action instantiation. Usage: invoke via pytest (e.g.,
|
||||
`pytest -k TestCotAgentOutputParser`).
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Basic streaming & usage
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("hello world")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "".join(result) == "hello world"
|
||||
|
||||
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk(None)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
|
||||
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_usage_update(self, make_chunk, usage_dict) -> None:
|
||||
usage_data = {"tokens": 99}
|
||||
chunks = [make_chunk("abc", usage=usage_data)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert usage_dict["usage"] == usage_data
|
||||
|
||||
# --------------------------------------------------------
|
||||
# JSON parsing (direct + streaming)
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '{"action": "search", "input": "query"}'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
|
||||
|
||||
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '[{"action": "lookup", "input": "abc"}]'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
|
||||
content = '{"foo": "bar"}'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# Expect the serialized JSON to be yielded as a single element.
|
||||
assert result == [json.dumps({"foo": "bar"})]
|
||||
|
||||
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
|
||||
content = "{invalid json}"
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert any("invalid json" in str(r) for r in result)
|
||||
|
||||
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
chunks = [
|
||||
make_chunk('{"action": '),
|
||||
make_chunk('"multi", '),
|
||||
make_chunk('"input": "step"}'),
|
||||
]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
|
||||
|
||||
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('{"foo": "bar"')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('{"foo": "bar"' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Code block JSON extraction
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = """```json
|
||||
{"action": "lookup", "input": "abc"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
# Multiple JSON objects inside single code fence (invalid combined JSON)
|
||||
# Parser should safely ignore invalid combined block
|
||||
content = """```json
|
||||
{"action": "a1", "input": "x"}
|
||||
{"action": "a2", "input": "y"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# No valid parsed action expected due to invalid combined JSON
|
||||
assert mock_action_class.call_count == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
|
||||
content = """```json
|
||||
{invalid}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result
|
||||
|
||||
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('```json {"a":1}')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('```json {"a":1}' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Action / Thought prefix handling
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
" action: something",
|
||||
" ACTION: something",
|
||||
" thought: reasoning",
|
||||
" THOUGHT: reasoning",
|
||||
],
|
||||
)
|
||||
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
joined = "".join(str(item) for item in result)
|
||||
expected_word = "something" if "action:" in content.lower() else "reasoning"
|
||||
assert expected_word in joined
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("xaction: test")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "x" in "".join(map(str, result))
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Mixed streaming scenarios
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = 'start {"action": "mix", "input": "1"} end'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# JSON action should be parsed
|
||||
mock_action_class.assert_called_once()
|
||||
# Ensure surrounding text is streamed (character-level)
|
||||
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
|
||||
assert "start" in joined
|
||||
assert "end" in joined
|
||||
|
||||
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert mock_action_class.call_count == 2
|
||||
|
||||
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("text with ` random ` backticks")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "text with" in "".join(result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Boundary & edge inputs
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
"```",
|
||||
"{",
|
||||
"}",
|
||||
"```json",
|
||||
"action:",
|
||||
"thought:",
|
||||
" ",
|
||||
],
|
||||
)
|
||||
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
joined = "".join(result)
|
||||
if content == " ":
|
||||
assert result == [] or joined == content
|
||||
if content in {"```", "{", "}", "```json"}:
|
||||
assert content in joined
|
||||
if content.lower() in {"action:", "thought:"}:
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
174
api/tests/unit_tests/core/agent/strategy/test_base.py
Normal file
174
api/tests/unit_tests/core/agent/strategy/test_base.py
Normal file
@ -0,0 +1,174 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
|
||||
|
||||
class DummyStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Concrete implementation for testing BaseAgentStrategy
|
||||
"""
|
||||
|
||||
def __init__(self, return_values=None, raise_exception=None):
|
||||
self.return_values = return_values or []
|
||||
self.raise_exception = raise_exception
|
||||
self.received_args = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id=None,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
credentials=None,
|
||||
) -> Generator:
|
||||
self.received_args = (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
credentials,
|
||||
)
|
||||
|
||||
if self.raise_exception:
|
||||
raise self.raise_exception
|
||||
|
||||
yield from self.return_values
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInstantiation:
|
||||
def test_cannot_instantiate_abstract_class(self) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
BaseAgentStrategy()
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInvoke:
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
return MagicMock(name="AgentInvokeMessage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials(self):
|
||||
return MagicMock(name="InvokeCredentials")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("params", "user_id", "conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
|
||||
({}, "user2", None, None, None),
|
||||
({"a": 1}, "", "", "", ""),
|
||||
({"nested": {"x": 1}}, "user3", None, "app3", None),
|
||||
],
|
||||
)
|
||||
def test_invoke_success(
|
||||
self,
|
||||
mock_message,
|
||||
mock_credentials,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[mock_message])
|
||||
|
||||
# Act
|
||||
result = list(
|
||||
strategy.invoke(
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
credentials=mock_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_message]
|
||||
assert strategy.received_args == (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
mock_credentials,
|
||||
)
|
||||
|
||||
def test_invoke_multiple_yields(self, mock_message) -> None:
|
||||
# Arrange
|
||||
messages = [mock_message, MagicMock(), MagicMock()]
|
||||
strategy = DummyStrategy(return_values=messages)
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == messages
|
||||
|
||||
def test_invoke_empty_generator(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_invoke_propagates_exception(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(raise_exception=ValueError("failure"))
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="failure"):
|
||||
list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_params",
|
||||
[
|
||||
None,
|
||||
"",
|
||||
123,
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
|
||||
"""
|
||||
Base class does not validate types — ensure pass-through behavior
|
||||
"""
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params=invalid_params, user_id="user"))
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_invoke_none_user_id(self) -> None:
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params={}, user_id=None))
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestBaseAgentStrategyGetParameters:
|
||||
def test_get_parameters_default_empty_list(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
result = strategy.get_parameters()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result == []
|
||||
|
||||
def test_get_parameters_returns_new_list_each_time(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
|
||||
first = strategy.get_parameters()
|
||||
second = strategy.get_parameters()
|
||||
|
||||
assert first == second == []
|
||||
assert first is not second
|
||||
272
api/tests/unit_tests/core/agent/strategy/test_plugin.py
Normal file
272
api/tests/unit_tests/core/agent/strategy/test_plugin.py
Normal file
@ -0,0 +1,272 @@
|
||||
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
|
||||
# ============================================================
|
||||
# Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parameter():
|
||||
def _factory(name="param", return_value="initialized"):
|
||||
param = MagicMock()
|
||||
param.name = name
|
||||
param.init_frontend_parameter = MagicMock(return_value=return_value)
|
||||
return param
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_declaration(mock_parameter):
|
||||
param1 = mock_parameter("param1", "init1")
|
||||
param2 = mock_parameter("param2", "init2")
|
||||
|
||||
identity = MagicMock()
|
||||
identity.provider = "provider_x"
|
||||
identity.name = "strategy_x"
|
||||
|
||||
declaration = MagicMock()
|
||||
declaration.parameters = [param1, param2]
|
||||
declaration.identity = identity
|
||||
|
||||
return declaration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(mock_declaration):
|
||||
return PluginAgentStrategy(
|
||||
tenant_id="tenant_123",
|
||||
declaration=mock_declaration,
|
||||
meta_version="v1",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Initialization Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestPluginAgentStrategyInitialization:
|
||||
def test_init_sets_attributes(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version="meta_v",
|
||||
)
|
||||
|
||||
assert strategy.tenant_id == "tenant_test"
|
||||
assert strategy.declaration == mock_declaration
|
||||
assert strategy.meta_version == "meta_v"
|
||||
|
||||
def test_init_meta_version_none(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version=None,
|
||||
)
|
||||
|
||||
assert strategy.meta_version is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetParameters:
|
||||
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
|
||||
result = strategy.get_parameters()
|
||||
assert result == mock_declaration.parameters
|
||||
|
||||
|
||||
# ============================================================
|
||||
# initialize_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInitializeParameters:
|
||||
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
|
||||
params = {"param1": "value1"}
|
||||
|
||||
result = strategy.initialize_parameters(params.copy())
|
||||
|
||||
assert result["param1"] == "init1"
|
||||
assert result["param2"] == "init2"
|
||||
|
||||
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
|
||||
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_params",
|
||||
[
|
||||
{},
|
||||
{"param1": None},
|
||||
{"param1": ""},
|
||||
{"param1": 0},
|
||||
{"param1": []},
|
||||
{"param1": {}, "param2": "value"},
|
||||
],
|
||||
)
|
||||
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
|
||||
result = strategy.initialize_parameters(input_params.copy())
|
||||
|
||||
for param in strategy.declaration.parameters:
|
||||
assert param.name in result
|
||||
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.initialize_parameters(None)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _invoke Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInvoke:
|
||||
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mock_convert = mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={"converted": True},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={"param1": "value"},
|
||||
user_id="user_1",
|
||||
conversation_id="conv_1",
|
||||
app_id="app_1",
|
||||
message_id="msg_1",
|
||||
credentials=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["msg1", "msg2"]
|
||||
mock_convert.assert_called_once()
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
call_kwargs = mock_manager.invoke.call_args.kwargs
|
||||
assert call_kwargs["tenant_id"] == "tenant_123"
|
||||
assert call_kwargs["user_id"] == "user_1"
|
||||
assert call_kwargs["agent_provider"] == "provider_x"
|
||||
assert call_kwargs["agent_strategy"] == "strategy_x"
|
||||
assert call_kwargs["agent_params"] == {"converted": True}
|
||||
assert call_kwargs["conversation_id"] == "conv_1"
|
||||
assert call_kwargs["app_id"] == "app_1"
|
||||
assert call_kwargs["message_id"] == "msg_1"
|
||||
assert call_kwargs["context"] is not None
|
||||
|
||||
def test_invoke_with_credentials(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
# Patch PluginInvokeContext to bypass pydantic validation
|
||||
mock_context = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginInvokeContext",
|
||||
return_value=mock_context,
|
||||
)
|
||||
|
||||
credentials = MagicMock()
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
(None, None, None),
|
||||
("conv", None, None),
|
||||
(None, "app", None),
|
||||
(None, None, "msg"),
|
||||
],
|
||||
)
|
||||
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
side_effect=ValueError("conversion failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
|
||||
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
802
api/tests/unit_tests/core/agent/test_base_agent_runner.py
Normal file
802
api/tests/unit_tests/core/agent/test_base_agent_runner.py
Normal file
@ -0,0 +1,802 @@
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.agent.base_agent_runner as module
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
|
||||
# ==========================================================
|
||||
# Fixtures
|
||||
# ==========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mocker):
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, mock_db_session):
|
||||
r = BaseAgentRunner.__new__(BaseAgentRunner)
|
||||
r.tenant_id = "tenant"
|
||||
r.user_id = "user"
|
||||
r.agent_thought_count = 0
|
||||
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
|
||||
r.app_config = mocker.MagicMock()
|
||||
r.app_config.app_id = "app1"
|
||||
r.app_config.agent = None
|
||||
r.dataset_tools = []
|
||||
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
|
||||
r._current_thoughts = []
|
||||
return r
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _repack_app_generate_entity
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == "abc"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# update_prompt_message_tool
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def build_param(self, mocker, **kwargs):
|
||||
p = mocker.MagicMock()
|
||||
p.form = kwargs.get("form")
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
p.type = mock_type
|
||||
|
||||
p.name = kwargs.get("name", "p1")
|
||||
p.llm_description = "desc"
|
||||
p.input_schema = kwargs.get("input_schema")
|
||||
p.options = kwargs.get("options")
|
||||
p.required = kwargs.get("required", False)
|
||||
return p
|
||||
|
||||
def test_skip_non_llm(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form="NOT_LLM")
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_enum_and_required(self, runner, mocker):
|
||||
option = mocker.MagicMock(value="opt1")
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
options=[option],
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool = mocker.MagicMock()
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert "p1" in result.parameters["required"]
|
||||
|
||||
def test_skip_file_type_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
|
||||
param.type = module.ToolParameter.ToolParameterType.FILE
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["required"].count("p1") == 1
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# create_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
|
||||
assert result == "11"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# save_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestSaveAgentThought:
|
||||
def setup_agent(self, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;tool2"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
return agent
|
||||
|
||||
def test_not_found(self, runner, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mock_label = mocker.MagicMock()
|
||||
mock_label.to_dict.return_value = {"en_US": "label"}
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
|
||||
|
||||
usage = mocker.MagicMock(
|
||||
prompt_tokens=1,
|
||||
prompt_price_unit=Decimal("0.1"),
|
||||
prompt_unit_price=Decimal("0.1"),
|
||||
completion_tokens=2,
|
||||
completion_price_unit=Decimal("0.2"),
|
||||
completion_unit_price=Decimal("0.2"),
|
||||
total_tokens=3,
|
||||
total_price=Decimal("0.3"),
|
||||
)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1;tool2",
|
||||
{"a": 1},
|
||||
"thought",
|
||||
{"b": 2},
|
||||
{"meta": 1},
|
||||
"answer",
|
||||
["f1"],
|
||||
usage,
|
||||
)
|
||||
|
||||
assert agent.answer == "answer"
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
bad_obj = MagicMock()
|
||||
bad_obj.__str__.return_value = "bad"
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
bad_obj,
|
||||
None,
|
||||
bad_obj,
|
||||
bad_obj,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
{"a": 1},
|
||||
None,
|
||||
{"b": 2},
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_user_prompt
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_history
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
observation="invalid",
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
observation=json.dumps({"tool1": "obs"}),
|
||||
thought="thinking",
|
||||
)
|
||||
|
||||
msg = mocker.MagicMock(
|
||||
id="m100",
|
||||
agent_thoughts=[thought],
|
||||
answer=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _convert_tool_to_prompt_message_tool (new coverage)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
runtime_param = mocker.MagicMock()
|
||||
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param.name = "param1"
|
||||
runtime_param.llm_description = "desc"
|
||||
runtime_param.required = True
|
||||
runtime_param.input_schema = None
|
||||
runtime_param.options = None
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
runtime_param.type = mock_type
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert entity == tool_entity
|
||||
|
||||
def test_full_conversion_multiple_params(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
# LLM param with input_schema override
|
||||
param1 = mocker.MagicMock()
|
||||
param1.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param1.name = "p1"
|
||||
param1.llm_description = "desc"
|
||||
param1.required = True
|
||||
param1.input_schema = {"type": "integer"}
|
||||
param1.options = None
|
||||
param1.type = mocker.MagicMock()
|
||||
|
||||
# SYSTEM_FILES param should be skipped
|
||||
param2 = mocker.MagicMock()
|
||||
param2.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param2.name = "file_param"
|
||||
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert entity == tool_entity
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _init_prompt_tools additional branches
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert tools == {}
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "p1"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.options = None
|
||||
param.input_schema = {"type": "number"}
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
param.type = mock_type
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"]["p1"]["type"] == "number"
|
||||
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
|
||||
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
|
||||
|
||||
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
|
||||
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
# ================= Additional Surgical Coverage =================
|
||||
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = True
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert prompt_tool is not None
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
|
||||
assert prompt is not None
|
||||
|
||||
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
|
||||
|
||||
llm = mocker.MagicMock()
|
||||
llm.get_model_schema.return_value = mocker.MagicMock(
|
||||
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
|
||||
)
|
||||
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
|
||||
|
||||
app_config = mocker.MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.agent = None
|
||||
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
|
||||
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
|
||||
|
||||
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
|
||||
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
|
||||
|
||||
runner = BaseAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=app_generate,
|
||||
conversation=mocker.MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=mocker.MagicMock(),
|
||||
config=mocker.MagicMock(),
|
||||
queue_manager=mocker.MagicMock(),
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
assert runner.stream_tool_call is True
|
||||
assert runner.files == ["file1"]
|
||||
assert runner.dataset_tools == ["ds_tool"]
|
||||
assert runner.agent_thought_count == 2
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = "NOT_LLM"
|
||||
param.type = mocker.MagicMock()
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert prompt_tool.parameters["properties"] == {}
|
||||
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
|
||||
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
|
||||
|
||||
tools, prompt_tools = runner._init_prompt_tools()
|
||||
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
tool_input = {"a": 1}
|
||||
observation = {"b": 2}
|
||||
tool_meta = {"c": 3}
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def dumps_side_effect(value, *args, **kwargs):
|
||||
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
|
||||
raise TypeError("fail")
|
||||
return real_dumps(value, *args, **kwargs)
|
||||
|
||||
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1",
|
||||
tool_input,
|
||||
None,
|
||||
observation,
|
||||
tool_meta,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
system_message = module.SystemPromptMessage(content="sys")
|
||||
|
||||
result = runner.organize_agent_history([system_message])
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
observation=None,
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"organize_agent_user_prompt",
|
||||
return_value=module.UserPromptMessage(content="user"),
|
||||
)
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
|
||||
assert any(isinstance(item, module.ToolPromptMessage) for item in result)
|
||||
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
@ -0,0 +1,551 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
|
||||
class DummyRunner(CotAgentRunner):
|
||||
"""Concrete implementation for testing abstract methods."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
# Minimal required defaults
|
||||
self.history_prompt_messages = []
|
||||
self.memory = None
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
return_value=[],
|
||||
)
|
||||
# Prepare required constructor dependencies for BaseAgentRunner
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock()
|
||||
application_generate_entity.model_conf.stop = []
|
||||
application_generate_entity.model_conf.provider = "openai"
|
||||
application_generate_entity.model_conf.parameters = {}
|
||||
application_generate_entity.trace_manager = None
|
||||
application_generate_entity.invoke_from = "test"
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock()
|
||||
app_config.agent.max_iteration = 1
|
||||
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.model = "test-model"
|
||||
|
||||
queue_manager = MagicMock()
|
||||
message = MagicMock()
|
||||
|
||||
runner = DummyRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=model_config,
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Patch internal methods to isolate behavior
|
||||
runner._repack_app_generate_entity = MagicMock()
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.create_agent_thought = MagicMock(return_value="thought-id")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
class TestFillInputs:
|
||||
@pytest.mark.parametrize(
|
||||
("instruction", "inputs", "expected"),
|
||||
[
|
||||
("Hello {{name}}", {"name": "John"}, "Hello John"),
|
||||
("No placeholders", {"name": "John"}, "No placeholders"),
|
||||
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
|
||||
("{{x}}", {"x": None}, "None"),
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
action_str="action1",
|
||||
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
|
||||
observation="obs1",
|
||||
)
|
||||
sp2 = AgentScratchpadUnit(
|
||||
agent_response="final",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
|
||||
observation=None,
|
||||
)
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
# Simulate final state via action name
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
action=None,
|
||||
observation="obs",
|
||||
)
|
||||
# Non-final state: provide a non-final action
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Thought:" in result
|
||||
assert "Action:" in result
|
||||
assert "Observation:" in result
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("result", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, [])
|
||||
assert response == "result"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
)
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
|
||||
usage_1 = LLMUsage.empty_usage()
|
||||
usage_1.prompt_tokens = 1
|
||||
usage_1.completion_tokens = 1
|
||||
usage_1.total_tokens = 2
|
||||
usage_1.prompt_price = 1
|
||||
usage_1.completion_price = 1
|
||||
usage_1.total_price = 2
|
||||
|
||||
usage_2 = LLMUsage.empty_usage()
|
||||
usage_2.prompt_tokens = 1
|
||||
usage_2.completion_tokens = 1
|
||||
usage_2.total_tokens = 2
|
||||
usage_2.prompt_price = 1
|
||||
usage_2.completion_price = 1
|
||||
usage_2.total_price = 2
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
handle_output = mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[
|
||||
[action],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_side_effect(chunks, usage_dict):
|
||||
call_index = handle_output.call_count
|
||||
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
|
||||
return [action] if call_index == 1 else []
|
||||
|
||||
handle_output.side_effect = _handle_side_effect
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
final_usage = results[-1].delta.usage
|
||||
assert final_usage is not None
|
||||
assert final_usage.prompt_tokens == 2
|
||||
assert final_usage.completion_tokens == 2
|
||||
assert final_usage.total_tokens == 4
|
||||
assert final_usage.prompt_price == 2
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
# First iteration → action
|
||||
# Second iteration → no action (empty list)
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[[action], []],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.app_config.agent.max_iteration = 5
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
|
||||
|
||||
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
|
||||
|
||||
runner.history_prompt_messages = [assistant, tool_msg]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
runner.history_prompt_messages = [assistant]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
|
||||
runner._init_react_state("new-query")
|
||||
|
||||
assert runner._query == "new-query"
|
||||
assert runner._agent_scratchpad == []
|
||||
assert runner._historic_prompt_messages == ["historic"]
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
|
||||
)
|
||||
|
||||
message_file_ids = []
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
|
||||
|
||||
assert response == "ok"
|
||||
assert message_file_ids == ["file1"]
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
|
||||
# Should silently continue on exception
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
|
||||
runner.history_prompt_messages = [user_message]
|
||||
|
||||
mock_transform = mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
)
|
||||
mock_transform.return_value.get_prompt.return_value = ["final"]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
return_value=mock_transform,
|
||||
)
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=["thinking"],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
# Remove invalid branch: Pydantic enforces str|dict for action_input
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "12345"
|
||||
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
@ -0,0 +1,215 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyAgentConfig,
|
||||
DummyAppConfig,
|
||||
DummyTool,
|
||||
)
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyPromptEntity as DummyPrompt,
|
||||
)
|
||||
|
||||
|
||||
class DummyFileUploadConfig:
|
||||
def __init__(self, image_config=None):
|
||||
self.image_config = image_config
|
||||
|
||||
|
||||
class DummyImageConfig:
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, file_upload_config=None):
|
||||
self.file_upload_config = file_upload_config
|
||||
|
||||
|
||||
class DummyUnit:
|
||||
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
|
||||
runner._instruction = "test_instruction"
|
||||
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
|
||||
runner._query = "user query"
|
||||
runner._agent_scratchpad = []
|
||||
runner.files = []
|
||||
runner.application_generate_entity = DummyGenerateEntity()
|
||||
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
|
||||
return runner
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_chat_agent_runner.jsonable_encoder",
|
||||
return_value=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
|
||||
result = runner._organize_system_prompt()
|
||||
|
||||
assert "test_instruction" in result.content
|
||||
assert "tool1" in result.content
|
||||
assert "tool2" in result.content
|
||||
assert "tool1, tool2" in result.content
|
||||
|
||||
def test_organize_system_prompt_missing_agent(self, runner):
|
||||
runner.app_config = DummyAppConfig(agent=None)
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
def test_organize_system_prompt_missing_prompt(self, runner):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
|
||||
def test_organize_user_query_no_files(self, runner, files):
|
||||
runner.files = files
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "query"
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
|
||||
image_config = DummyImageConfig(detail="high")
|
||||
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
|
||||
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assert "system" in result
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(final=True, agent_response="done")
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(
|
||||
final=False,
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
observation="observe",
|
||||
)
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: thinking" in combined
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
units = [
|
||||
DummyUnit(final=False, thought="t1"),
|
||||
DummyUnit(final=True, agent_response="done"),
|
||||
]
|
||||
runner._agent_scratchpad = units
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: t1" in combined
|
||||
assert "Final Answer: done" in combined
|
||||
@ -0,0 +1,234 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, dummy_tool_factory):
|
||||
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
|
||||
|
||||
runner._instruction = "Test instruction"
|
||||
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
|
||||
runner._query = "What is Python?"
|
||||
runner._agent_scratchpad = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_completion_agent_runner.jsonable_encoder",
|
||||
side_effect=lambda tools: [{"name": t.name} for t in tools],
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_instruction_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
)
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
result = runner._organize_instruction_prompt()
|
||||
|
||||
assert "Test instruction" in result
|
||||
assert "toolA" in result
|
||||
assert "toolB" in result
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_historic_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[user_msg, assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_prompt_messages Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
|
||||
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
|
||||
content = result[0].content
|
||||
|
||||
assert "History" in content
|
||||
assert "Thought: Thinking" in content
|
||||
assert "Action: Act" in content
|
||||
assert "Observation: Obs" in content
|
||||
assert "Final Answer: Done" in content
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = None
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert "Question: What is Python?" in result[0].content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("thought", "action", "observation"),
|
||||
[
|
||||
("T", None, None),
|
||||
("T", "A", None),
|
||||
("T", None, "O"),
|
||||
],
|
||||
)
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(
|
||||
final=False,
|
||||
thought=thought,
|
||||
action_str=action,
|
||||
observation=observation,
|
||||
)
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
content = result[0].content
|
||||
|
||||
assert "Thought:" in content
|
||||
if action:
|
||||
assert "Action:" in content
|
||||
if observation:
|
||||
assert "Observation:" in content
|
||||
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
@ -0,0 +1,452 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
# ==============================
|
||||
# Dummy Helper Classes
|
||||
# ==============================
|
||||
|
||||
|
||||
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = pt
|
||||
usage.completion_tokens = ct
|
||||
usage.total_tokens = tt
|
||||
usage.prompt_price = 0
|
||||
usage.completion_price = 0
|
||||
usage.total_price = 0
|
||||
return usage
|
||||
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
|
||||
self.content: str | None = content
|
||||
self.tool_calls: list[Any] = tool_calls or []
|
||||
|
||||
|
||||
class DummyDelta:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
|
||||
|
||||
|
||||
class DummyResult:
|
||||
def __init__(
|
||||
self,
|
||||
message: DummyMessage | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
prompt_messages: list[DummyMessage] | None = None,
|
||||
):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
self.prompt_messages: list[DummyMessage] = prompt_messages or []
|
||||
self.system_fingerprint: str = ""
|
||||
|
||||
|
||||
# ==============================
|
||||
# Fixtures
|
||||
# ==============================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Patch streaming chunk models to avoid validation on dummy message objects
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock(max_iteration=2)
|
||||
app_config.prompt_template = MagicMock(simple_prompt_template="system")
|
||||
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
|
||||
application_generate_entity.trace_manager = MagicMock()
|
||||
application_generate_entity.invoke_from = "test"
|
||||
application_generate_entity.app_config = MagicMock(app_id="app")
|
||||
application_generate_entity.file_upload_config = None
|
||||
|
||||
queue_manager = MagicMock()
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
|
||||
message = MagicMock(id="msg1")
|
||||
conversation = MagicMock(id="conv1")
|
||||
|
||||
runner = FunctionCallAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=MagicMock(),
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Manually inject required attributes normally set by BaseAgentRunner
|
||||
runner.tenant_id = "tenant"
|
||||
runner.application_generate_entity = application_generate_entity
|
||||
runner.conversation = conversation
|
||||
runner.app_config = app_config
|
||||
runner.model_config = MagicMock()
|
||||
runner.config = MagicMock()
|
||||
runner.queue_manager = queue_manager
|
||||
runner.message = message
|
||||
runner.user_id = "user"
|
||||
runner.model_instance = model_instance
|
||||
|
||||
runner.stream_tool_call = False
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.agent_callback = MagicMock()
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.create_agent_thought = MagicMock(return_value="thought1")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ==============================
|
||||
# Tool Call Checks
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
|
||||
# ==============================
|
||||
# Extract Tool Calls
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = ""
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
tool_call.function.arguments = json.dumps({"x": 2})
|
||||
|
||||
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_blocking_tool_calls(result)
|
||||
|
||||
assert calls == [("2", "block", {"x": 2})]
|
||||
|
||||
|
||||
# ==============================
|
||||
# System Message Initialization
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ==============================
|
||||
# Organize User Query
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
return_value=file_content,
|
||||
)
|
||||
|
||||
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
|
||||
runner.files = ["file1"]
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Clear User Prompt Images
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
|
||||
image = MagicMock()
|
||||
image.type = "image"
|
||||
image.data = "img"
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.__class__.__name__ = "UserPromptMessage"
|
||||
user_msg.content = [text, image]
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
|
||||
user_msg = UserPromptMessage(content=[text, image, document])
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
|
||||
assert result[0].content == "hello\n[image]\n[file]"
|
||||
|
||||
|
||||
# ==============================
|
||||
# Run Method Tests
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def flaky_dumps(obj, *args, **kwargs):
|
||||
if kwargs.get("ensure_ascii") is False:
|
||||
return real_dumps(obj, *args, **kwargs)
|
||||
raise TypeError("boom")
|
||||
|
||||
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "missing"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
tool_instance = MagicMock()
|
||||
prompt_tool = MagicMock()
|
||||
prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
|
||||
|
||||
tool_invoke_meta = MagicMock()
|
||||
tool_invoke_meta.to_dict.return_value = {"ok": True}
|
||||
mocker.patch(
|
||||
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], tool_invoke_meta),
|
||||
)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
assert any(
|
||||
isinstance(call.args[0], QueueMessageFileEvent)
|
||||
and call.args[0].message_file_id == "file1"
|
||||
and call.args[1] == PublishFrom.APPLICATION_MANAGER
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query"))
|
||||
324
api/tests/unit_tests/core/agent/test_plugin_entities.py
Normal file
324
api/tests/unit_tests/core/agent/test_plugin_entities.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""Unit tests for core.agent.plugin_entities.
|
||||
|
||||
Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
|
||||
Pydantic ValidationError behavior and pytest fixtures for validation and
|
||||
mocking; ensure entity invariants and validation rules remain stable.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agent.plugin_entities import (
|
||||
AgentFeature,
|
||||
AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity,
|
||||
AgentStrategyIdentity,
|
||||
AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity,
|
||||
AgentStrategyProviderIdentity,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
|
||||
|
||||
# =========================================================
|
||||
# Fixtures
|
||||
# =========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyIdentity)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameterType Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameterType:
|
||||
@pytest.mark.parametrize(
|
||||
"enum_member",
|
||||
list(AgentStrategyParameter.AgentStrategyParameterType),
|
||||
)
|
||||
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
return_value="normalized",
|
||||
)
|
||||
|
||||
result = enum_member.as_normal_type()
|
||||
|
||||
mock_func.assert_called_once_with(enum_member)
|
||||
assert result == "normalized"
|
||||
|
||||
def test_as_normal_type_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
enum_member.as_normal_type()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("enum_member", "value"),
|
||||
[
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
|
||||
],
|
||||
)
|
||||
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
return_value="casted",
|
||||
)
|
||||
|
||||
result = enum_member.cast_value(value)
|
||||
|
||||
mock_func.assert_called_once_with(enum_member, value)
|
||||
assert result == "casted"
|
||||
|
||||
def test_cast_value_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
side_effect=ValueError("invalid"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
enum_member.cast_value("bad")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameter Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameter:
|
||||
def test_valid_creation_minimal(self) -> None:
|
||||
# bypass base PluginParameter required fields using model_construct
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=None,
|
||||
)
|
||||
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
assert param.help is None
|
||||
|
||||
def test_valid_creation_with_help(self) -> None:
|
||||
help_obj = I18nObject(en_US="test")
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=help_obj,
|
||||
)
|
||||
assert param.help == help_obj
|
||||
|
||||
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
|
||||
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
|
||||
|
||||
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
|
||||
|
||||
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
return_value="frontend",
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
result = param.init_frontend_parameter("value")
|
||||
|
||||
mock_func.assert_called_once_with(param, param.type, "value")
|
||||
assert result == "frontend"
|
||||
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
side_effect=RuntimeError("error"),
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
param.init_frontend_parameter("value")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyProviderEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyProviderEntity:
|
||||
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="plugin-123",
|
||||
)
|
||||
assert entity.plugin_id == "plugin-123"
|
||||
|
||||
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="",
|
||||
)
|
||||
assert entity.plugin_id == ""
|
||||
|
||||
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
|
||||
assert entity.plugin_id is None
|
||||
|
||||
def test_invalid_identity_raises(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyProviderEntity(identity="invalid")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyEntity:
|
||||
def test_parameters_default_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=None,
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_preserved(self, mock_identity) -> None:
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[param],
|
||||
)
|
||||
assert entity.parameters == [param]
|
||||
|
||||
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters="invalid",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"features",
|
||||
[
|
||||
None,
|
||||
[],
|
||||
[AgentFeature.HISTORY_MESSAGES],
|
||||
],
|
||||
)
|
||||
def test_features_valid(self, mock_identity, features) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features=features,
|
||||
)
|
||||
assert entity.features == features
|
||||
|
||||
def test_invalid_features_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features="invalid",
|
||||
)
|
||||
|
||||
def test_output_schema_and_meta_version(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
output_schema={"type": "object"},
|
||||
meta_version="v1",
|
||||
)
|
||||
assert entity.output_schema == {"type": "object"}
|
||||
assert entity.meta_version == "v1"
|
||||
|
||||
def test_missing_required_fields_raise(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(identity=mock_identity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentProviderEntityWithPlugin Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentProviderEntityWithPlugin:
|
||||
def test_default_strategies_empty(self, mock_provider_identity) -> None:
|
||||
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
|
||||
assert entity.strategies == []
|
||||
|
||||
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
|
||||
strategy = AgentStrategyEntity.model_construct(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
entity = AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies=[strategy],
|
||||
)
|
||||
assert entity.strategies == [strategy]
|
||||
|
||||
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies="invalid",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# Inheritance Smoke Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestInheritanceBehavior:
|
||||
def test_agent_strategy_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyIdentity, ToolIdentity)
|
||||
|
||||
def test_agent_strategy_provider_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)
|
||||
Loading…
Reference in New Issue
Block a user