Mirror of https://github.com/langgenius/dify.git (synced 2026-03-26 13:51:05 +08:00)

test: unit test cases for sub modules in core.app (except core.app.apps) (#32476)

This commit is contained in: parent e873cea99e, commit 36cc1bf025
@@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent


 class SuspendLayer(GraphEngineLayer):
     """ """

     def __init__(self) -> None:
         super().__init__()
+        self._paused = False

     def on_graph_start(self):
-        pass
+        self._paused = False

     def on_event(self, event: GraphEngineEvent):
+        """
+        Handle the paused event, stash runtime state into storage and wait for resume.
+        """
         if isinstance(event, GraphRunPausedEvent):
-            pass
+            self._paused = True

     def on_graph_end(self, error: Exception | None):
         """ """
-        pass
+        self._paused = False
+
+    def is_paused(self) -> bool:
+        return self._paused
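A minimal usage sketch of the SuspendLayer lifecycle, driving the hooks directly; in the real engine the layer mechanism invokes these, and the event construction below is an assumption, not this diff's API:

    # Sketch only: `paused_event` stands in for a real GraphRunPausedEvent instance,
    # whose constructor arguments are not shown in this hunk.
    layer = SuspendLayer()
    layer.on_graph_start()        # run start resets the flag
    layer.on_event(paused_event)  # a GraphRunPausedEvent flips it
    assert layer.is_paused()
    layer.on_graph_end(None)      # run end clears it again
    assert not layer.is_paused()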
@@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
             self._handle_graph_run_paused(event)
             return

+        if isinstance(event, NodeRunStartedEvent):
+            self._handle_node_started(event)
+            return
+
         if isinstance(event, NodeRunRetryEvent):
             self._handle_node_retry(event)
             return

-        if isinstance(event, NodeRunStartedEvent):
-            self._handle_node_started(event)
-            return
-
         if isinstance(event, NodeRunSucceededEvent):
             self._handle_node_succeeded(event)
             return
@@ -68,7 +68,7 @@ def init_tool_node(config: dict):
     return node


-def test_tool_variable_invoke():
+def test_tool_variable_invoke(monkeypatch):
     node = init_tool_node(
         config={
             "id": "1",
@@ -103,7 +103,7 @@ def test_tool_variable_invoke():
     assert item.node_run_result.outputs.get("text") is not None


-def test_tool_mixed_invoke():
+def test_tool_mixed_invoke(monkeypatch):
     node = init_tool_node(
         config={
             "id": "1",
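Both signatures now take pytest's built-in monkeypatch fixture, which applies temporary attribute patches that are automatically undone after each test. A self-contained sketch of the pattern (the patched target here is illustrative, not the tests' actual stub):

    import json

    def test_monkeypatch_pattern(monkeypatch):
        # Temporarily replace an attribute; pytest restores it on teardown.
        monkeypatch.setattr(json, "loads", lambda s: {"text": "stubbed"})
        assert json.loads("{}") == {"text": "stubbed"}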
@@ -0,0 +1,227 @@
from unittest.mock import MagicMock

import pytest

# Module under test
from core.app.app_config.common import parameters_mapping


class TestGetParametersFromFeatureDict:
    """Test suite for get_parameters_from_feature_dict"""

    @pytest.fixture
    def mock_config(self, monkeypatch):
        """Mock dify_config values"""
        mock = MagicMock()
        mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1
        mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2
        mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3
        mock.UPLOAD_FILE_SIZE_LIMIT = 4
        mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5

        monkeypatch.setattr(parameters_mapping, "dify_config", mock)
        return mock

    @pytest.fixture
    def mock_default_file_limits(self, monkeypatch):
        """Mock DEFAULT_FILE_NUMBER_LIMITS constant"""
        monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99)
        return 99

    @pytest.fixture
    def minimal_inputs(self):
        return {}, []

    @pytest.mark.parametrize(
        ("feature_key", "expected_default"),
        [
            ("suggested_questions", []),
            ("suggested_questions_after_answer", {"enabled": False}),
            ("speech_to_text", {"enabled": False}),
            ("text_to_speech", {"enabled": False}),
            ("retriever_resource", {"enabled": False}),
            ("annotation_reply", {"enabled": False}),
            ("more_like_this", {"enabled": False}),
            (
                "sensitive_word_avoidance",
                {"enabled": False, "type": "", "configs": []},
            ),
        ],
    )
    def test_defaults_when_key_missing(
        self,
        feature_key,
        expected_default,
        mock_config,
        mock_default_file_limits,
    ):
        # Arrange
        features = {}
        user_input = []

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=user_input,
        )

        # Assert
        assert result[feature_key] == expected_default

    def test_opening_statement_present(self, mock_config, mock_default_file_limits):
        # Arrange
        features = {"opening_statement": "Hello"}

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=[],
        )

        # Assert
        assert result["opening_statement"] == "Hello"

    def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits):
        # Arrange
        features = {}

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=[],
        )

        # Assert
        assert result["opening_statement"] is None

    def test_all_features_provided(self, mock_config, mock_default_file_limits):
        # Arrange
        features = {
            "opening_statement": "Hi",
            "suggested_questions": ["Q1"],
            "suggested_questions_after_answer": {"enabled": True},
            "speech_to_text": {"enabled": True},
            "text_to_speech": {"enabled": True},
            "retriever_resource": {"enabled": True},
            "annotation_reply": {"enabled": True},
            "more_like_this": {"enabled": True},
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "strict",
                "configs": ["a"],
            },
            "file_upload": {
                "image": {
                    "enabled": True,
                    "number_limits": 10,
                    "detail": "low",
                    "transfer_methods": ["local_file"],
                }
            },
        }
        user_input = [{"name": "field1"}]

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=user_input,
        )

        # Assert
        for key in features:
            assert result[key] == features[key]
        assert result["user_input_form"] == user_input

    def test_file_upload_default_structure(self, mock_config, mock_default_file_limits):
        # Arrange
        features = {}

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=[],
        )

        # Assert
        file_upload = result["file_upload"]
        assert file_upload["image"]["enabled"] is False
        assert file_upload["image"]["number_limits"] == 99
        assert file_upload["image"]["detail"] == "high"
        assert "remote_url" in file_upload["image"]["transfer_methods"]
        assert "local_file" in file_upload["image"]["transfer_methods"]

    def test_system_parameters_from_config(self, mock_config, mock_default_file_limits):
        # Arrange
        features = {}

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=[],
        )

        # Assert
        system_params = result["system_parameters"]
        assert system_params["image_file_size_limit"] == 1
        assert system_params["video_file_size_limit"] == 2
        assert system_params["audio_file_size_limit"] == 3
        assert system_params["file_size_limit"] == 4
        assert system_params["workflow_file_upload_limit"] == 5

    @pytest.mark.parametrize(
        ("features_dict", "user_input_form"),
        [
            (None, []),
            ([], []),
            ("invalid", []),
        ],
    )
    def test_invalid_features_dict_type_raises(self, features_dict, user_input_form):
        # Act & Assert
        with pytest.raises(AttributeError):
            parameters_mapping.get_parameters_from_feature_dict(
                features_dict=features_dict,
                user_input_form=user_input_form,
            )

    @pytest.mark.parametrize(
        "user_input_form",
        [None, "invalid", 123],
    )
    def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form):
        # Arrange
        features = {}

        # Act
        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=user_input_form,
        )

        # Assert
        assert result["user_input_form"] == user_input_form

    def test_empty_user_input_form(self, mock_config, mock_default_file_limits):
        features = {}
        user_input = []

        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=user_input,
        )

        assert result["user_input_form"] == []

    def test_feature_values_none(self, mock_config, mock_default_file_limits):
        features = {
            "suggested_questions": None,
            "speech_to_text": None,
        }

        result = parameters_mapping.get_parameters_from_feature_dict(
            features_dict=features,
            user_input_form=[],
        )

        assert result["suggested_questions"] is None
        assert result["speech_to_text"] is None
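Taken together, the suite above pins down the function's contract; a condensed direct-call sketch follows, with expected values mirroring the defaults asserted in the tests (running it outside the app assumes dify_config is importable and configured):

    from core.app.app_config.common import parameters_mapping

    params = parameters_mapping.get_parameters_from_feature_dict(
        features_dict={"opening_statement": "Hello"},
        user_input_form=[{"name": "field1"}],
    )
    assert params["opening_statement"] == "Hello"  # provided values pass through
    assert params["suggested_questions"] == []     # missing keys fall back to defaults
    assert "system_parameters" in params           # size limits sourced from dify_config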
@@ -0,0 +1,202 @@
from unittest.mock import MagicMock

import pytest

from core.app.app_config.common.sensitive_word_avoidance.manager import (
    SensitiveWordAvoidanceConfigManager,
)


class TestSensitiveWordAvoidanceConfigManagerConvert:
    """Tests for convert classmethod"""

    @pytest.mark.parametrize(
        "config",
        [
            {},
            {"sensitive_word_avoidance": None},
            {"sensitive_word_avoidance": {}},
            {"sensitive_word_avoidance": {"enabled": False}},
        ],
    )
    def test_convert_returns_none_when_disabled_or_missing(self, config):
        # Act
        result = SensitiveWordAvoidanceConfigManager.convert(config)

        # Assert
        assert result is None

    def test_convert_returns_entity_when_enabled(self, mocker):
        # Arrange
        mock_entity = MagicMock()
        mocker.patch(
            "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity",
            return_value=mock_entity,
        )

        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "mock_type",
                "config": {"key": "value"},
            }
        }

        # Act
        result = SensitiveWordAvoidanceConfigManager.convert(config)

        # Assert
        assert result == mock_entity

    def test_convert_enabled_without_type_or_config(self, mocker):
        # Arrange
        mock_entity = MagicMock()
        patched = mocker.patch(
            "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity",
            return_value=mock_entity,
        )

        config = {"sensitive_word_avoidance": {"enabled": True}}

        # Act
        result = SensitiveWordAvoidanceConfigManager.convert(config)

        # Assert
        patched.assert_called_once_with(type=None, config={})
        assert result == mock_entity


class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults:
    """Tests for validate_and_set_defaults classmethod"""

    @pytest.fixture
    def base_config(self):
        return {}

    def test_validate_sets_default_when_missing(self, base_config):
        # Act
        config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id="tenant1", config=base_config.copy()
        )

        # Assert
        assert config["sensitive_word_avoidance"]["enabled"] is False
        assert fields == ["sensitive_word_avoidance"]

    def test_validate_raises_when_not_dict(self):
        config = {"sensitive_word_avoidance": "invalid"}

        with pytest.raises(ValueError, match="must be of dict type"):
            SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)

    @pytest.mark.parametrize(
        "config",
        [
            {"sensitive_word_avoidance": {"enabled": False}},
            {"sensitive_word_avoidance": {"enabled": None}},
            {"sensitive_word_avoidance": {}},
        ],
    )
    def test_validate_disables_when_enabled_false_or_missing(self, config):
        # Act
        result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id="tenant1", config=config
        )

        # Assert
        assert result_config["sensitive_word_avoidance"]["enabled"] is False

    def test_validate_raises_when_enabled_true_without_type(self):
        config = {"sensitive_word_avoidance": {"enabled": True}}

        with pytest.raises(ValueError, match="type is required"):
            SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)

    def test_validate_raises_when_type_not_string(self):
        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": 123,
            }
        }

        with pytest.raises(ValueError, match="must be a string"):
            SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)

    def test_validate_raises_when_config_not_dict(self):
        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "mock_type",
                "config": "invalid",
            }
        }

        with pytest.raises(ValueError, match="must be a dict"):
            SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)

    def test_validate_calls_moderation_factory(self, mocker):
        # Arrange
        mock_validate = mocker.patch(
            "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
        )

        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "mock_type",
                "config": {"k": "v"},
            }
        }

        # Act
        result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id="tenant1", config=config
        )

        # Assert
        mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"})
        assert result_config["sensitive_word_avoidance"]["enabled"] is True
        assert fields == ["sensitive_word_avoidance"]

    def test_validate_sets_empty_dict_when_config_none(self, mocker):
        # Arrange
        mock_validate = mocker.patch(
            "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
        )

        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "mock_type",
                "config": None,
            }
        }

        # Act
        SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)

        # Assert
        mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={})

    def test_validate_only_structure_validate_skips_factory(self, mocker):
        # Arrange
        mock_validate = mocker.patch(
            "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
        )

        config = {
            "sensitive_word_avoidance": {
                "enabled": True,
                "type": "mock_type",
                "config": {"k": "v"},
            }
        }

        # Act
        SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id="tenant1", config=config, only_structure_validate=True
        )

        # Assert
        mock_validate.assert_not_called()
@@ -0,0 +1,236 @@
from unittest.mock import MagicMock

import pytest

from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager


class TestAgentConfigManagerConvert:
    @pytest.fixture
    def base_config(self):
        return {
            "agent_mode": {
                "enabled": True,
                "strategy": "cot",
                "tools": [],
            },
            "model": {
                "provider": "openai",
                "name": "gpt-4",
                "mode": "completion",
            },
        }

    def test_convert_returns_none_when_agent_mode_missing(self):
        config = {"model": {"provider": "openai", "name": "gpt-4"}}

        result = AgentConfigManager.convert(config)

        assert result is None

    @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}])
    def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config):
        config = base_config.copy()
        config["agent_mode"] = agent_mode_value

        result = AgentConfigManager.convert(config)

        assert result is None

    @pytest.mark.parametrize(
        ("strategy_input", "expected_enum"),
        [
            ("function_call", "FUNCTION_CALLING"),
            ("cot", "CHAIN_OF_THOUGHT"),
            ("react", "CHAIN_OF_THOUGHT"),
        ],
    )
    def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config):
        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": strategy_input,
            "tools": [],
        }

        result = AgentConfigManager.convert(config)

        assert result is not None
        assert result.strategy.name == expected_enum

    def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config):
        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "unknown_strategy",
            "tools": [],
        }
        config["model"]["provider"] = "openai"

        result = AgentConfigManager.convert(config)

        assert result.strategy.name == "FUNCTION_CALLING"

    def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config):
        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "unknown_strategy",
            "tools": [],
        }
        config["model"]["provider"] = "anthropic"

        result = AgentConfigManager.convert(config)

        assert result.strategy.name == "CHAIN_OF_THOUGHT"

    def test_convert_skips_disabled_tools(self, mocker, base_config):
        # Patch AgentEntity to bypass pydantic validation
        mock_agent_entity = mocker.patch(
            "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity",
            return_value=MagicMock(),
        )

        mock_validate = mocker.patch(
            "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate",
            return_value={
                "provider_type": "type2",
                "provider_id": "id2",
                "tool_name": "tool2",
                "tool_parameters": {},
                "credential_id": None,
            },
        )

        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "cot",
            "tools": [
                {
                    "provider_type": "type1",
                    "provider_id": "id1",
                    "tool_name": "tool1",
                    "enabled": False,
                },
                {
                    "provider_type": "type2",
                    "provider_id": "id2",
                    "tool_name": "tool2",
                    "enabled": True,
                    "extra_key": "x",
                },
            ],
        }

        AgentConfigManager.convert(config)

        mock_validate.assert_called_once()
        mock_agent_entity.assert_called_once()

    def test_convert_tool_requires_minimum_keys(self, mocker, base_config):
        mock_validate = mocker.patch(
            "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate",
            return_value=MagicMock(),
        )

        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "cot",
            "tools": [
                {"a": 1, "b": 2},  # insufficient keys
            ],
        }

        result = AgentConfigManager.convert(config)

        assert result is not None
        assert result.tools == []
        mock_validate.assert_not_called()

    def test_convert_completion_mode_prompt_defaults(self, base_config):
        config = base_config.copy()
        config["agent_mode"]["prompt"] = {}
        config["model"]["mode"] = "completion"

        result = AgentConfigManager.convert(config)

        assert result is not None
        assert result.prompt.first_prompt is not None
        assert result.prompt.next_iteration is not None

    def test_convert_chat_mode_prompt_defaults(self, base_config):
        config = base_config.copy()
        config["agent_mode"]["prompt"] = {}
        config["model"]["mode"] = "chat"

        result = AgentConfigManager.convert(config)

        assert result is not None
        assert result.prompt.first_prompt is not None
        assert result.prompt.next_iteration is not None

    def test_convert_router_strategy_returns_none(self, base_config):
        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "router",
            "tools": [],
        }

        result = AgentConfigManager.convert(config)

        assert result is None

    def test_convert_react_router_strategy_returns_none(self, base_config):
        config = base_config.copy()
        config["agent_mode"] = {
            "enabled": True,
            "strategy": "react_router",
            "tools": [],
        }

        result = AgentConfigManager.convert(config)

        assert result is None

    def test_convert_max_iteration_default(self, base_config):
        config = base_config.copy()
        config["agent_mode"].pop("max_iteration", None)

        result = AgentConfigManager.convert(config)

        assert result.max_iteration == 10

    def test_convert_custom_max_iteration(self, base_config):
        config = base_config.copy()
        config["agent_mode"]["max_iteration"] = 25

        result = AgentConfigManager.convert(config)

        assert result.max_iteration == 25

    def test_convert_missing_model_raises_key_error(self, base_config):
        config = base_config.copy()
        del config["model"]

        with pytest.raises(KeyError):
            AgentConfigManager.convert(config)

    @pytest.mark.parametrize(
        ("invalid_config", "should_raise"),
        [
            (None, True),
            (123, True),
            ("", False),
            ([], False),
        ],
    )
    def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise):
        if should_raise:
            with pytest.raises(TypeError):
                AgentConfigManager.convert(invalid_config)  # type: ignore
        else:
            result = AgentConfigManager.convert(invalid_config)  # type: ignore
            assert result is None
@@ -0,0 +1,319 @@
import uuid
from unittest.mock import MagicMock

import pytest

from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode

# ==============================
# Fixtures
# ==============================


@pytest.fixture
def valid_uuid():
    return str(uuid.uuid4())


@pytest.fixture
def base_config(valid_uuid):
    return {
        "dataset_configs": {
            "retrieval_model": "multiple",
            "datasets": {
                "strategy": "router",
                "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
            },
        }
    }


@pytest.fixture
def mock_dataset_service(mocker, valid_uuid):
    mock_dataset = MagicMock()
    mock_dataset.tenant_id = "tenant1"

    mocker.patch(
        "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
        return_value=mock_dataset,
    )


# ==============================
# convert tests
# ==============================


class TestDatasetConfigManagerConvert:
    def test_convert_returns_none_when_no_datasets(self):
        config = {"dataset_configs": {"datasets": {"datasets": []}}}
        result = DatasetConfigManager.convert(config)
        assert result is None

    def test_convert_single_retrieval(self, valid_uuid):
        config = {
            "dataset_query_variable": "query",
            "dataset_configs": {
                "retrieval_model": "single",
                "datasets": {
                    "strategy": "router",
                    "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
                },
            },
        }

        result = DatasetConfigManager.convert(config)
        assert result is not None
        assert result.dataset_ids == [valid_uuid]
        assert result.retrieve_config.query_variable == "query"

    def test_convert_single_with_metadata_configs(self, valid_uuid, mocker):
        mock_retrieve_config = MagicMock()
        mock_entity = MagicMock()
        mock_entity.dataset_ids = [valid_uuid]
        mock_entity.retrieve_config = mock_retrieve_config

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig",
            return_value={"mock": "model"},
        )
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition",
            return_value={"mock": "condition"},
        )
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity",
            return_value=mock_retrieve_config,
        )
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity",
            return_value=mock_entity,
        )

        config = {
            "dataset_query_variable": "query",
            "dataset_configs": {
                "retrieval_model": "single",
                "metadata_filtering_mode": "manual",
                "metadata_model_config": {"any": "value"},
                "metadata_filtering_conditions": {"any": "value"},
                "datasets": {
                    "strategy": "router",
                    "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
                },
            },
        }
        result = DatasetConfigManager.convert(config)
        assert result.dataset_ids == [valid_uuid]
        assert result.retrieve_config is mock_retrieve_config

    def test_convert_multiple_defaults(self, valid_uuid):
        config = {
            "dataset_configs": {
                "retrieval_model": "multiple",
                "datasets": {
                    "strategy": "router",
                    "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
                },
            }
        }
        result = DatasetConfigManager.convert(config)
        assert result.retrieve_config.top_k == 4
        assert result.retrieve_config.score_threshold is None
        assert result.retrieve_config.reranking_enabled is True

    def test_convert_agent_mode_disabled_tool(self, valid_uuid):
        config = {
            "agent_mode": {
                "enabled": True,
                "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}],
            }
        }
        result = DatasetConfigManager.convert(config)
        assert result is None

    def test_convert_dataset_configs_none(self):
        config = {"dataset_configs": None}
        with pytest.raises(TypeError):
            DatasetConfigManager.convert(config)

    def test_convert_agent_mode_old_style_old_format(self, valid_uuid):
        config = {
            "agent_mode": {
                "enabled": True,
                "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
            }
        }
        result = DatasetConfigManager.convert(config)
        assert result.dataset_ids == [valid_uuid]
        assert result.retrieve_config.query_variable is None

    def test_convert_multiple_with_score_threshold(self, valid_uuid):
        config = {
            "dataset_query_variable": "query",
            "dataset_configs": {
                "retrieval_model": "multiple",
                "top_k": 5,
                "score_threshold": 0.8,
                "score_threshold_enabled": True,
                "datasets": {
                    "strategy": "router",
                    "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
                },
            },
        }

        result = DatasetConfigManager.convert(config)
        assert result.retrieve_config.top_k == 5
        assert result.retrieve_config.score_threshold == 0.8

    @pytest.mark.parametrize(
        "dataset_entry",
        [
            {},
            {"invalid": {}},
            {"dataset": {"id": None, "enabled": True}},
            {"dataset": {"id": "", "enabled": False}},
        ],
    )
    def test_convert_ignores_invalid_dataset_entries(self, dataset_entry):
        config = {
            "dataset_configs": {
                "retrieval_model": "multiple",
                "datasets": {"strategy": "router", "datasets": [dataset_entry]},
            }
        }
        result = DatasetConfigManager.convert(config)
        assert result is None

    def test_convert_agent_mode_old_style(self, valid_uuid):
        config = {
            "agent_mode": {
                "enabled": True,
                "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
            }
        }
        result = DatasetConfigManager.convert(config)
        assert result.dataset_ids == [valid_uuid]


# ==============================
# validate_and_set_defaults tests
# ==============================


class TestValidateAndSetDefaults:
    def test_validate_sets_defaults(self):
        config = {}
        updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config)
        assert "dataset_configs" in updated
        assert updated["dataset_configs"]["retrieval_model"] == "single"
        assert isinstance(fields, list)

    def test_validate_raises_when_dataset_configs_not_dict(self):
        config = {"dataset_configs": "invalid"}
        with pytest.raises(AttributeError):
            DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config)

    def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid):
        config = {
            "dataset_configs": {
                "datasets": {
                    "strategy": "router",
                    "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
                }
            }
        }
        with pytest.raises(ValueError):
            DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config)


# ==============================
# extract_dataset_config_for_legacy_compatibility tests
# ==============================


class TestExtractDatasetConfig:
    def test_extract_sets_defaults(self):
        config = {}
        result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
        assert "agent_mode" in result
        assert result["agent_mode"]["enabled"] is False
        assert result["agent_mode"]["tools"] == []

    def test_extract_invalid_agent_mode_type(self):
        config = {"agent_mode": "invalid"}
        with pytest.raises(ValueError):
            DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)

    def test_extract_invalid_enabled_type(self):
        config = {"agent_mode": {"enabled": "yes"}}
        with pytest.raises(ValueError):
            DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)

    def test_extract_invalid_tools_type(self):
        config = {"agent_mode": {"enabled": True, "tools": "invalid"}}
        with pytest.raises(ValueError):
            DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)

    def test_extract_invalid_uuid(self, mocker):
        invalid_uuid = "not-a-uuid"
        config = {
            "agent_mode": {
                "enabled": True,
                "strategy": PlanningStrategy.ROUTER,
                "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}],
            }
        }
        with pytest.raises(ValueError):
            DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)

    def test_extract_dataset_not_exists(self, valid_uuid, mocker):
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
            return_value=None,
        )
        config = {
            "agent_mode": {
                "enabled": True,
                "strategy": PlanningStrategy.ROUTER,
                "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
            }
        }
        with pytest.raises(ValueError):
            DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)


# ==============================
# is_dataset_exists tests
# ==============================


class TestIsDatasetExists:
    def test_dataset_exists_true(self, mocker, valid_uuid):
        mock_dataset = MagicMock()
        mock_dataset.tenant_id = "tenant1"
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
            return_value=mock_dataset,
        )

        assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)

    def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid):
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
            return_value=None,
        )
        assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)

    def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid):
        mock_dataset = MagicMock()
        mock_dataset.tenant_id = "other"
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
            return_value=mock_dataset,
        )
        assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)
@@ -0,0 +1,234 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.entities.model_entities import ModelStatus
from core.errors.error import (
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey


class TestModelConfigConverter:
    @pytest.fixture(autouse=True)
    def patch_response_entity(self, mocker):
        """
        Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation
        and return a simple namespace object instead.
        """

        def _factory(**kwargs):
            return SimpleNamespace(**kwargs)

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity",
            side_effect=_factory,
        )

    @pytest.fixture
    def mock_app_config(self):
        app_config = MagicMock()
        app_config.tenant_id = "tenant_1"

        model_config = MagicMock()
        model_config.provider = "openai"
        model_config.model = "gpt-4"
        model_config.parameters = {"temperature": 0.5}
        model_config.mode = None

        app_config.model = model_config
        return app_config

    @pytest.fixture
    def mock_provider_bundle(self):
        bundle = MagicMock()

        # configuration
        configuration = MagicMock()
        configuration.provider.provider = "openai"
        configuration.get_current_credentials.return_value = {"api_key": "key"}

        provider_model = MagicMock()
        provider_model.status = ModelStatus.ACTIVE
        configuration.get_provider_model.return_value = provider_model

        bundle.configuration = configuration

        # model type instance
        model_type_instance = MagicMock()
        model_schema = MagicMock()
        model_schema.model_properties = {}
        model_type_instance.get_model_schema.return_value = model_schema
        bundle.model_type_instance = model_type_instance

        return bundle

    @pytest.fixture
    def patch_provider_manager(self, mocker, mock_provider_bundle):
        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )
        return mock_manager

    # =============================
    # Positive Scenarios
    # =============================

    def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager):
        result = ModelConfigConverter.convert(mock_app_config)

        assert result.provider == "openai"
        assert result.model == "gpt-4"
        assert result.mode == LLMMode.CHAT
        assert result.parameters == {"temperature": 0.5}
        assert result.stop == []

    def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager):
        mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]}

        result = ModelConfigConverter.convert(mock_app_config)

        assert result.parameters == {"temperature": 0.7}
        assert result.stop == ["\n"]

    def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker):
        mock_app_config.model.mode = None

        mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
            ModelPropertyKey.MODE: LLMMode.COMPLETION.value
        }

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        result = ModelConfigConverter.convert(mock_app_config)
        assert result.mode == LLMMode.COMPLETION

    def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker):
        mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
            ModelPropertyKey.MODE: "invalid"
        }

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        result = ModelConfigConverter.convert(mock_app_config)
        assert result.mode == LLMMode.CHAT

    # =============================
    # Credential Errors
    # =============================

    def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
        mock_provider_bundle.configuration.get_current_credentials.return_value = None

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        with pytest.raises(ProviderTokenNotInitError):
            ModelConfigConverter.convert(mock_app_config)

    # =============================
    # Provider Model Errors
    # =============================

    def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
        mock_provider_bundle.configuration.get_provider_model.return_value = None

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        with pytest.raises(ValueError):
            ModelConfigConverter.convert(mock_app_config)

    @pytest.mark.parametrize(
        ("status", "expected_exception"),
        [
            (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError),
            (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError),
            (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError),
        ],
    )
    def test_convert_provider_model_status_errors(
        self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception
    ):
        mock_provider = MagicMock()
        mock_provider.status = status
        mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        with pytest.raises(expected_exception):
            ModelConfigConverter.convert(mock_app_config)

    # =============================
    # Schema Errors
    # =============================

    def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
        mock_provider_bundle.model_type_instance.get_model_schema.return_value = None

        mock_manager = MagicMock()
        mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
            return_value=mock_manager,
        )

        with pytest.raises(ValueError):
            ModelConfigConverter.convert(mock_app_config)

    # =============================
    # Edge Cases
    # =============================

    @pytest.mark.parametrize(
        "parameters",
        [
            {},
            {"stop": []},
            {"stop": ["END"], "max_tokens": 100},
        ],
    )
    def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters):
        mock_app_config.model.parameters = parameters.copy()

        result = ModelConfigConverter.convert(mock_app_config)

        if "stop" in parameters:
            assert result.stop == parameters.get("stop")
            expected_params = parameters.copy()
            expected_params.pop("stop", None)
            assert result.parameters == expected_params
        else:
            assert result.stop == []
            assert result.parameters == parameters
@@ -0,0 +1,230 @@
from unittest.mock import MagicMock

import pytest

# Target
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager

# -----------------------------
# Fixtures
# -----------------------------


@pytest.fixture
def valid_completion_params():
    return {"temperature": 0.7, "stop": ["\n"]}


@pytest.fixture
def valid_model_list():
    model = MagicMock()
    model.model = "gpt-4"
    model.model_properties = {"mode": "chat"}
    return [model]


@pytest.fixture
def provider_entities():
    provider = MagicMock()
    provider.provider = "openai/gpt"
    return [provider]


@pytest.fixture
def valid_config():
    return {
        "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}}
    }


# -----------------------------
# Test Class
# -----------------------------


class TestModelConfigManager:
    # ==========================================================
    # convert
    # ==========================================================

    def test_convert_success(self, valid_config):
        result = ModelConfigManager.convert(valid_config)

        assert result.provider == "openai/gpt"
        assert result.model == "gpt-4"
        assert result.parameters == {"temperature": 0.5}
        assert result.stop == ["END"]

    def test_convert_missing_model(self):
        with pytest.raises(ValueError, match="model is required"):
            ModelConfigManager.convert({})

    def test_convert_without_stop(self):
        config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}}
        result = ModelConfigManager.convert(config)
        assert result.stop == []
        assert result.parameters == {"temperature": 0.9}

    # ==========================================================
    # validate_model_completion_params
    # ==========================================================

    @pytest.mark.parametrize(
        "invalid_cp",
        [None, "string", 123, []],
    )
    def test_validate_model_completion_params_invalid_type(self, invalid_cp):
        with pytest.raises(ValueError, match="must be of object type"):
            ModelConfigManager.validate_model_completion_params(invalid_cp)

    def test_validate_model_completion_params_default_stop(self):
        cp = {"temperature": 0.2}
        result = ModelConfigManager.validate_model_completion_params(cp)
        assert result["stop"] == []

    def test_validate_model_completion_params_invalid_stop_type(self):
        cp = {"stop": "invalid"}
        with pytest.raises(ValueError, match="must be of list type"):
            ModelConfigManager.validate_model_completion_params(cp)

    def test_validate_model_completion_params_stop_length_exceeded(self):
        cp = {"stop": [1, 2, 3, 4, 5]}
        with pytest.raises(ValueError, match="less than 4"):
            ModelConfigManager.validate_model_completion_params(cp)

    # ==========================================================
    # validate_and_set_defaults
    # ==========================================================

    def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list):
        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list

        updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config)

        assert updated_config["model"]["mode"] == "chat"
        assert keys == ["model"]

    def test_validate_and_set_defaults_missing_model(self):
        with pytest.raises(ValueError, match="model is required"):
            ModelConfigManager.validate_and_set_defaults("tenant1", {})

    def test_validate_and_set_defaults_model_not_dict(self):
        with pytest.raises(ValueError, match="object type"):
            ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"})

    def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities):
        config = {"model": {"name": "gpt-4", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        with pytest.raises(ValueError, match="model.provider is required"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities):
        config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        with pytest.raises(ValueError, match="model.provider is required"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities):
        config = {"model": {"provider": "openai/gpt", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        with pytest.raises(ValueError, match="model.name is required"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities):
        config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = []

        with pytest.raises(ValueError, match="must be in the specified model list"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list):
        config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list

        with pytest.raises(ValueError, match="must be in the specified model list"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities):
        model = MagicMock()
        model.model = "gpt-4"
        model.model_properties = {}

        config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model]

        updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config)

        assert updated_config["model"]["mode"] == "completion"

    def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list):
        config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}}

        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        mock_factory.return_value.get_providers.return_value = provider_entities

        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list

        with pytest.raises(ValueError, match="completion_params is required"):
            ModelConfigManager.validate_and_set_defaults("tenant1", config)

    def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list):
        """
        Covers branch where provider does not contain '/' and
        ModelProviderID conversion is triggered (line 64).
        """
        config = {
            "model": {
                "provider": "openai",  # no slash -> triggers conversion
                "name": "gpt-4",
                "completion_params": {},
            }
        }

        # Mock ModelProviderID to return formatted provider
        mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID")
        mock_provider_id.return_value = "openai/gpt"

        # Mock provider factory
        mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
        provider_entity = MagicMock()
        provider_entity.provider = "openai/gpt"
        mock_factory.return_value.get_providers.return_value = [provider_entity]

        # Mock provider manager
        mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
        mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list

        updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config)

        # Ensure conversion happened
        mock_provider_id.assert_called_once_with("openai")
        assert updated_config["model"]["provider"] == "openai/gpt"
@ -0,0 +1,292 @@
from unittest.mock import MagicMock

import pytest

from core.app.app_config.easy_ui_based_app.prompt_template.manager import (
    PromptTemplateConfigManager,
)

# -----------------------------
# Helpers
# -----------------------------


class DummyEnumValue:
    def __init__(self, value):
        self.value = value


class DummyPromptType:
    def __init__(self):
        self.SIMPLE = "simple"
        self.ADVANCED = "advanced"

    def value_of(self, value):
        return value

    def __iter__(self):
        return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")])


# -----------------------------
# Convert Tests
# -----------------------------


class TestPromptTemplateConfigManagerConvert:
    def test_convert_missing_prompt_type_raises(self):
        with pytest.raises(ValueError, match="prompt_type is required"):
            PromptTemplateConfigManager.convert({})

    def test_convert_simple_prompt(self, mocker):
        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = DummyPromptType()

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )

        mock_prompt_entity_cls.return_value = "simple_entity"

        config = {"prompt_type": "simple", "pre_prompt": "hello"}

        result = PromptTemplateConfigManager.convert(config)

        assert result == "simple_entity"
        mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello")

    def test_convert_advanced_chat_valid(self, mocker):
        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = DummyPromptType()
        mock_prompt_entity_cls.return_value = "advanced_entity"

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of",
            return_value="role_enum",
        )

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity",
            return_value="chat_msg",
        )

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity",
            return_value="chat_template",
        )

        config = {
            "prompt_type": "advanced",
            "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]},
        }

        result = PromptTemplateConfigManager.convert(config)

        assert result == "advanced_entity"

    @pytest.mark.parametrize(
        "message",
        [
            {"text": 123, "role": "user"},
            {"text": "hi", "role": 123},
        ],
    )
    def test_convert_advanced_invalid_message_fields(self, mocker, message):
        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = DummyPromptType()

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )

        config = {
            "prompt_type": "advanced",
            "chat_prompt_config": {"prompt": [message]},
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.convert(config)

    def test_convert_advanced_completion_with_roles(self, mocker):
        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = DummyPromptType()
        mock_prompt_entity_cls.return_value = "advanced_entity"

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity",
            return_value="completion_template",
        )

        config = {
            "prompt_type": "advanced",
            "completion_prompt_config": {
                "prompt": {"text": "complete"},
                "conversation_histories_role": {
                    "user_prefix": "U",
                    "assistant_prefix": "A",
                },
            },
        }

        result = PromptTemplateConfigManager.convert(config)

        assert result == "advanced_entity"


# -----------------------------
# validate_and_set_defaults
# -----------------------------


class TestValidateAndSetDefaults:
    def setup_method(self):
        self.valid_model = {"mode": "chat"}

    def _patch_prompt_type(self, mocker):
        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = DummyPromptType()
        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )
        return mock_prompt_entity_cls

    def test_default_prompt_type_set(self, mocker):
        self._patch_prompt_type(mocker)

        config = {"model": self.valid_model}

        result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

        assert result["prompt_type"] == "simple"
        assert isinstance(keys, list)

    def test_invalid_prompt_type_raises(self, mocker):
        class InvalidEnum(DummyPromptType):
            def __iter__(self):
                return iter([DummyEnumValue("valid")])

        mock_prompt_entity_cls = MagicMock()
        mock_prompt_entity_cls.PromptType = InvalidEnum()

        mocker.patch(
            "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
            mock_prompt_entity_cls,
        )

        config = {"prompt_type": "invalid", "model": self.valid_model}

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_invalid_chat_prompt_config_type(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "simple",
            "chat_prompt_config": "invalid",
            "model": self.valid_model,
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_simple_mode_invalid_pre_prompt_type(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "simple",
            "pre_prompt": 123,
            "model": self.valid_model,
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_advanced_requires_one_config(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "advanced",
            "chat_prompt_config": {},
            "completion_prompt_config": {},
            "model": {"mode": "chat"},
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_advanced_invalid_model_mode(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "advanced",
            "chat_prompt_config": {"prompt": []},
            "model": {"mode": "invalid"},
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_advanced_chat_prompt_length_exceeds(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "advanced",
            "chat_prompt_config": {"prompt": [{}] * 11},
            "model": {"mode": "chat"},
        }

        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)

    def test_completion_prefix_defaults_set_when_empty(self, mocker):
        self._patch_prompt_type(mocker)

        config = {
            "prompt_type": "advanced",
            "completion_prompt_config": {
                "prompt": {"text": "hi"},
                "conversation_histories_role": {
                    "user_prefix": "",
                    "assistant_prefix": "",
                },
            },
            "model": {"mode": "completion"},
        }

        updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config)

        roles = updated["completion_prompt_config"]["conversation_histories_role"]
        assert roles["user_prefix"] == "Human"
        assert roles["assistant_prefix"] == "Assistant"


# -----------------------------
# validate_post_prompt
# -----------------------------


class TestValidatePostPrompt:
    @pytest.mark.parametrize("value", [None, ""])
    def test_post_prompt_defaults(self, value):
        config = {"post_prompt": value}
        result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config)
        assert result["post_prompt"] == ""

    def test_post_prompt_invalid_type(self):
        config = {"post_prompt": 123}
        with pytest.raises(ValueError):
            PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config)
@ -0,0 +1,286 @@
import pytest

from core.app.app_config.easy_ui_based_app.variables.manager import (
    BasicVariablesConfigManager,
)
from dify_graph.variables.input_entities import VariableEntityType


class TestBasicVariablesConfigManagerConvert:
    def test_convert_empty_config(self):
        config = {}

        variables, external = BasicVariablesConfigManager.convert(config)

        assert variables == []
        assert external == []

    def test_convert_external_data_tools_enabled_and_disabled(self, mocker):
        config = {
            "external_data_tools": [
                {"enabled": False},
                {
                    "enabled": True,
                    "variable": "ext_var",
                    "type": "tool_type",
                    "config": {"k": "v"},
                },
            ]
        }

        variables, external = BasicVariablesConfigManager.convert(config)

        assert variables == []
        assert len(external) == 1
        assert external[0].variable == "ext_var"
        assert external[0].type == "tool_type"

    def test_convert_user_input_form_variable_types(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.TEXT_INPUT: {
                        "variable": "name",
                        "label": "Name",
                        "description": "desc",
                        "required": True,
                        "max_length": 50,
                    }
                },
                {
                    VariableEntityType.SELECT: {
                        "variable": "choice",
                        "label": "Choice",
                        "options": ["a", "b"],
                    }
                },
                {
                    VariableEntityType.EXTERNAL_DATA_TOOL: {
                        "variable": "ext",
                        "type": "tool",
                        "config": {"x": 1},
                    }
                },
            ]
        }

        variables, external = BasicVariablesConfigManager.convert(config)

        assert len(variables) == 2
        assert len(external) == 1

    def test_convert_external_data_tool_without_config_skipped(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.EXTERNAL_DATA_TOOL: {
                        "variable": "ext",
                        "type": "tool",
                    }
                }
            ]
        }

        variables, external = BasicVariablesConfigManager.convert(config)

        assert variables == []
        assert external == []


class TestValidateVariablesAndSetDefaults:
    def test_validate_sets_empty_user_input_form_if_missing(self):
        config = {}

        updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

        assert updated["user_input_form"] == []
        assert "user_input_form" in keys

    def test_validate_user_input_form_not_list_raises(self):
        config = {"user_input_form": "invalid"}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_invalid_key_raises(self):
        config = {"user_input_form": [{"invalid": {}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_missing_label_raises(self):
        config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_label_not_string_raises(self):
        config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_missing_variable_raises(self):
        config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_variable_not_string_raises(self):
        config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    @pytest.mark.parametrize(
        "variable_name",
        ["1invalid", "invalid space", "", None],
    )
    def test_validate_variable_invalid_pattern_raises(self, variable_name):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.TEXT_INPUT: {
                        "label": "Name",
                        "variable": variable_name,
                    }
                }
            ]
        }

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_required_default_and_type(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.TEXT_INPUT: {
                        "label": "Name",
                        "variable": "valid_name",
                    }
                }
            ]
        }

        updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

        assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False

    def test_validate_required_not_bool_raises(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.TEXT_INPUT: {
                        "label": "Name",
                        "variable": "valid_name",
                        "required": "yes",
                    }
                }
            ]
        }

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_select_options_default_not_in_options_raises(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.SELECT: {
                        "label": "Choice",
                        "variable": "choice",
                        "options": ["a", "b"],
                        "default": "c",
                    }
                }
            ]
        }

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)

    def test_validate_select_options_not_list_raises(self):
        config = {
            "user_input_form": [
                {
                    VariableEntityType.SELECT: {
                        "label": "Choice",
                        "variable": "choice",
                        "options": "not_list",
                    }
                }
            ]
        }

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_variables_and_set_defaults(config)


class TestValidateExternalDataToolsAndSetDefaults:
    def test_validate_sets_empty_external_data_tools_if_missing(self):
        config = {}

        updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)

        assert updated["external_data_tools"] == []
        assert "external_data_tools" in keys

    def test_validate_external_data_tools_not_list_raises(self):
        config = {"external_data_tools": "invalid"}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)

    def test_validate_disabled_tool_skipped(self, mocker):
        config = {"external_data_tools": [{"enabled": False}]}

        spy = mocker.patch(
            "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config"
        )

        updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)

        spy.assert_not_called()
        assert updated["external_data_tools"][0]["enabled"] is False

    def test_validate_enabled_tool_missing_type_raises(self):
        config = {"external_data_tools": [{"enabled": True, "config": {}}]}

        with pytest.raises(ValueError):
            BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)

    def test_validate_enabled_tool_calls_factory(self, mocker):
        config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]}

        spy = mocker.patch(
            "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config"
        )

        BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config)

        spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1})


class TestValidateAndSetDefaultsIntegration:
    def test_validate_and_set_defaults_calls_both(self, mocker):
        config = {}

        spy_var = mocker.patch.object(
            BasicVariablesConfigManager,
            "validate_variables_and_set_defaults",
            return_value=(config, ["user_input_form"]),
        )
        spy_ext = mocker.patch.object(
            BasicVariablesConfigManager,
            "validate_external_data_tools_and_set_defaults",
            return_value=(config, ["external_data_tools"]),
        )

        updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config)

        spy_var.assert_called_once()
        spy_ext.assert_called_once()
        assert "user_input_form" in keys
        assert "external_data_tools" in keys
        assert updated == config
@ -0,0 +1,115 @@
import pytest

from core.app.app_config.entities import TextToSpeechEntity
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
    SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager


class TestAdditionalFeatureManagers:
    def test_opening_statement_validate_defaults(self):
        config, keys = OpeningStatementConfigManager.validate_and_set_defaults({})
        assert config["opening_statement"] == ""
        assert config["suggested_questions"] == []
        assert set(keys) == {"opening_statement", "suggested_questions"}

    def test_opening_statement_validate_types(self):
        with pytest.raises(ValueError):
            OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123})
        with pytest.raises(ValueError):
            OpeningStatementConfigManager.validate_and_set_defaults(
                {"opening_statement": "hi", "suggested_questions": "bad"}
            )
        with pytest.raises(ValueError):
            OpeningStatementConfigManager.validate_and_set_defaults(
                {"opening_statement": "hi", "suggested_questions": [1]}
            )

    def test_opening_statement_convert(self):
        opening, questions = OpeningStatementConfigManager.convert(
            {"opening_statement": "hello", "suggested_questions": ["q1"]}
        )
        assert opening == "hello"
        assert questions == ["q1"]

    def test_retrieval_resource_validate(self):
        config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({})
        assert config["retriever_resource"]["enabled"] is False
        assert keys == ["retriever_resource"]

        with pytest.raises(ValueError):
            RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"})
        with pytest.raises(ValueError):
            RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}})

    def test_retrieval_resource_convert(self):
        assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True
        assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False

    def test_speech_to_text_validate_and_convert(self):
        config, keys = SpeechToTextConfigManager.validate_and_set_defaults({})
        assert config["speech_to_text"]["enabled"] is False
        assert keys == ["speech_to_text"]

        with pytest.raises(ValueError):
            SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"})
        with pytest.raises(ValueError):
            SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}})

        assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True
        assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False

    def test_suggested_questions_after_answer_validate_and_convert(self):
        config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({})
        assert config["suggested_questions_after_answer"]["enabled"] is False
        assert keys == ["suggested_questions_after_answer"]

        with pytest.raises(ValueError):
            SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
                {"suggested_questions_after_answer": "bad"}
            )
        with pytest.raises(ValueError):
            SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
                {"suggested_questions_after_answer": {"enabled": "yes"}}
            )

        assert (
            SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}})
            is True
        )
        assert (
            SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}})
            is False
        )

    def test_text_to_speech_validate_and_convert(self):
        config, keys = TextToSpeechConfigManager.validate_and_set_defaults({})
        assert config["text_to_speech"]["enabled"] is False
        assert keys == ["text_to_speech"]

        with pytest.raises(ValueError):
            TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"})
        with pytest.raises(ValueError):
            TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}})

        result = TextToSpeechConfigManager.convert(
            {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}}
        )
        assert isinstance(result, TextToSpeechEntity)
        assert result.voice == "v"
        assert result.language == "en"

    def test_more_like_this_convert_and_validate(self):
        config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({})
        assert config["more_like_this"]["enabled"] is False
        assert keys == ["more_like_this"]

        assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True
        assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False
        with pytest.raises(ValueError):
            MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"})
@ -0,0 +1,180 @@
from collections import UserDict
from unittest.mock import MagicMock

import pytest

from core.app.app_config.base_app_config_manager import BaseAppConfigManager


class TestBaseAppConfigManager:
    @pytest.fixture
    def mock_config_dict(self):
        return {"key": "value", "another": 123}

    @pytest.fixture
    def mock_app_additional_features(self, mocker):
        mock_instance = MagicMock()
        mocker.patch(
            "core.app.app_config.base_app_config_manager.AppAdditionalFeatures",
            return_value=mock_instance,
        )
        return mock_instance

    @pytest.fixture
    def mock_managers(self, mocker):
        retrieval = mocker.patch(
            "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
            return_value="retrieval_result",
        )
        file_upload = mocker.patch(
            "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert",
            return_value="file_upload_result",
        )
        opening_statement = mocker.patch(
            "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert",
            return_value=("opening_result", "suggested_result"),
        )
        suggested_after = mocker.patch(
            "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert",
            return_value="suggested_after_result",
        )
        more_like_this = mocker.patch(
            "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert",
            return_value="more_like_this_result",
        )
        speech_to_text = mocker.patch(
            "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert",
            return_value="speech_to_text_result",
        )
        text_to_speech = mocker.patch(
            "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert",
            return_value="text_to_speech_result",
        )

        return {
            "retrieval": retrieval,
            "file_upload": file_upload,
            "opening_statement": opening_statement,
            "suggested_after": suggested_after,
            "more_like_this": more_like_this,
            "speech_to_text": speech_to_text,
            "text_to_speech": text_to_speech,
        }

    @pytest.mark.parametrize(
        ("app_mode", "expected_is_vision"),
        [
            ("CHAT", True),
            ("COMPLETION", True),
            ("AGENT_CHAT", True),
            ("OTHER", False),
        ],
    )
    def test_convert_features_all_modes(
        self,
        mocker,
        mock_config_dict,
        mock_app_additional_features,
        mock_managers,
        app_mode,
        expected_is_vision,
    ):
        # Arrange
        mock_app_mode = MagicMock()
        mock_app_mode.CHAT = "CHAT"
        mock_app_mode.COMPLETION = "COMPLETION"
        mock_app_mode.AGENT_CHAT = "AGENT_CHAT"

        mocker.patch(
            "core.app.app_config.base_app_config_manager.AppMode",
            mock_app_mode,
        )

        # Act
        result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode)

        # Assert
        assert result == mock_app_additional_features
        mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items()))
        mock_managers["file_upload"].assert_called_once()
        _, kwargs = mock_managers["file_upload"].call_args
        assert kwargs["config"] == dict(mock_config_dict.items())
        assert kwargs["is_vision"] is expected_is_vision

        mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items()))
        mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items()))
        mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items()))
        mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items()))
        mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items()))

    def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers):
        # Arrange
        empty_config = {}
        mock_app_mode = MagicMock()
        mock_app_mode.CHAT = "CHAT"
        mock_app_mode.COMPLETION = "COMPLETION"
        mock_app_mode.AGENT_CHAT = "AGENT_CHAT"

        mocker.patch(
            "core.app.app_config.base_app_config_manager.AppMode",
            mock_app_mode,
        )

        # Act
        result = BaseAppConfigManager.convert_features(empty_config, "CHAT")

        # Assert
        assert result == mock_app_additional_features
        for manager in mock_managers.values():
            assert manager.called

    @pytest.mark.parametrize(
        "invalid_config",
        [
            None,
            "string",
            123,
            12.34,
            [],
        ],
    )
    def test_convert_features_invalid_config_raises(self, invalid_config):
        # Act & Assert
        with pytest.raises((TypeError, AttributeError)):
            BaseAppConfigManager.convert_features(invalid_config, "CHAT")

    def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict):
        # Arrange
        mocker.patch(
            "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
            side_effect=RuntimeError("manager failure"),
        )

        # Act & Assert
        with pytest.raises(RuntimeError):
            BaseAppConfigManager.convert_features(mock_config_dict, "CHAT")

    def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers):
        # Arrange
        class CustomMapping(UserDict):
            pass

        custom_config = CustomMapping({"a": 1})

        mock_app_mode = MagicMock()
        mock_app_mode.CHAT = "CHAT"
        mock_app_mode.COMPLETION = "COMPLETION"
        mock_app_mode.AGENT_CHAT = "AGENT_CHAT"

        mocker.patch(
            "core.app.app_config.base_app_config_manager.AppMode",
            mock_app_mode,
        )

        # Act
        result = BaseAppConfigManager.convert_features(custom_config, "CHAT")

        # Assert
        assert result == mock_app_additional_features
        for manager in mock_managers.values():
            assert manager.called
43
api/tests/unit_tests/core/app/app_config/test_entities.py
Normal file
@ -0,0 +1,43 @@
import pytest

from core.app.app_config.entities import (
    DatasetRetrieveConfigEntity,
    PromptTemplateEntity,
)
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType


class TestAppConfigEntities:
    def test_variable_entity_coerces_none_description_and_options(self):
        entity = VariableEntity(
            variable="query",
            label="Query",
            description=None,
            type=VariableEntityType.TEXT_INPUT,
            options=None,
        )

        assert entity.description == ""
        assert entity.options == []

    def test_variable_entity_rejects_invalid_json_schema(self):
        with pytest.raises(ValueError):
            VariableEntity(
                variable="query",
                label="Query",
                type=VariableEntityType.TEXT_INPUT,
                json_schema={"type": "string", "minLength": "bad"},
            )

    def test_prompt_template_value_of(self):
        assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE
        with pytest.raises(ValueError):
            PromptTemplateEntity.PromptType.value_of("missing")

    def test_dataset_retrieve_strategy_value_of(self):
        assert (
            DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single")
            == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        )
        with pytest.raises(ValueError):
            DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing")
@ -0,0 +1,222 @@
import pytest

from core.app.app_config.workflow_ui_based_app.variables.manager import (
    WorkflowVariablesConfigManager,
)

# =============================
# Fixtures
# =============================


@pytest.fixture
def mock_workflow(mocker):
    workflow = mocker.MagicMock()
    workflow.graph_dict = {"nodes": []}
    return workflow


@pytest.fixture
def mock_variable_entity(mocker):
    return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity")


@pytest.fixture
def mock_rag_entity(mocker):
    return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity")


# =============================
# Test Convert (user_input_form)
# =============================


class TestWorkflowVariablesConfigManagerConvert:
    def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity):
        # Arrange
        input_variables = [{"name": "var1"}, {"name": "var2"}]
        mock_workflow.user_input_form.return_value = input_variables
        mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert(mock_workflow)

        # Assert
        assert result == [{"validated": v} for v in input_variables]
        assert mock_variable_entity.model_validate.call_count == 2

    def test_convert_empty_list(self, mock_workflow, mock_variable_entity):
        # Arrange
        mock_workflow.user_input_form.return_value = []

        # Act
        result = WorkflowVariablesConfigManager.convert(mock_workflow)

        # Assert
        assert result == []
        mock_variable_entity.model_validate.assert_not_called()

    def test_convert_none_returned_raises(self, mock_workflow):
        # Arrange
        mock_workflow.user_input_form.return_value = None

        # Act & Assert
        with pytest.raises(TypeError):
            WorkflowVariablesConfigManager.convert(mock_workflow)

    def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity):
        # Arrange
        mock_workflow.user_input_form.return_value = [{"invalid": "data"}]
        mock_variable_entity.model_validate.side_effect = ValueError("validation error")

        # Act & Assert
        with pytest.raises(ValueError):
            WorkflowVariablesConfigManager.convert(mock_workflow)


# =============================
# Test convert_rag_pipeline_variable
# =============================


class TestWorkflowVariablesConfigManagerConvertRag:
    def test_no_rag_pipeline_variables(self, mock_workflow):
        # Arrange
        mock_workflow.rag_pipeline_variables = []

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert result == []

    def test_rag_pipeline_none(self, mock_workflow):
        # Arrange
        mock_workflow.rag_pipeline_variables = None

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert result == []

    def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": "node1"},
        ]
        mock_workflow.graph_dict = {"nodes": []}
        mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}]

    def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": "node1"},
            {"variable": "var2", "belong_to_node_id": "node1"},
        ]

        mock_workflow.graph_dict = {
            "nodes": [
                {
                    "id": "node1",
                    "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}},
                }
            ]
        }

        mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert len(result) == 1
        assert result[0]["validated"]["variable"] == "var2"

    def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": "node1"},
            {"variable": "var2", "belong_to_node_id": "node1"},
        ]

        mock_workflow.graph_dict = {
            "nodes": [
                {
                    "id": "node1",
                    "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}},
                }
            ]
        }

        mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert len(result) == 1
        assert result[0]["validated"]["variable"] == "var2"

    @pytest.mark.parametrize(
        ("belong_to_node_id", "expected_count"),
        [
            ("node1", 1),
            ("shared", 1),
            ("other_node", 0),
        ],
    )
    def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": belong_to_node_id},
        ]
        mock_workflow.graph_dict = {"nodes": []}
        mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert len(result) == expected_count

    def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": "node1"},
        ]

        mock_workflow.graph_dict = {
            "nodes": [
                {
                    "id": "node1",
                    "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}},
                }
            ]
        }

        mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}

        # Act
        result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

        # Assert
        assert len(result) == 1

    def test_validation_error_propagates(self, mock_workflow, mock_rag_entity):
        # Arrange
        mock_workflow.rag_pipeline_variables = [
            {"variable": "var1", "belong_to_node_id": "node1"},
        ]
        mock_workflow.graph_dict = {"nodes": []}
        mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed")

        # Act & Assert
        with pytest.raises(RuntimeError):
            WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
@ -0,0 +1,12 @@
from core.app.entities.queue_entities import QueueStopEvent


class TestQueueEntities:
    def test_get_stop_reason_for_known_stop_by(self):
        event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
        assert event.get_stop_reason() == "Stopped by user."

    def test_get_stop_reason_for_unknown_stop_by(self):
        event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
        event.stopped_by = "unknown"
        assert event.get_stop_reason() == "Stopped by unknown reason."
@ -0,0 +1,17 @@
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity


class TestRagPipelineInvokeEntity:
    def test_defaults_and_fields(self):
        entity = RagPipelineInvokeEntity(
            pipeline_id="pipe-1",
            application_generate_entity={"foo": "bar"},
            user_id="user-1",
            tenant_id="tenant-1",
            workflow_id="workflow-1",
            streaming=True,
        )

        assert entity.workflow_execution_id is None
        assert entity.workflow_thread_pool_id is None
        assert entity.streaming is True
78
api/tests/unit_tests/core/app/entities/test_task_entities.py
Normal file
@ -0,0 +1,78 @@
from core.app.entities.task_entities import (
    NodeFinishStreamResponse,
    NodeRetryStreamResponse,
    NodeStartStreamResponse,
    StreamEvent,
)
from dify_graph.enums import WorkflowNodeExecutionStatus


class TestTaskEntities:
    def test_node_start_to_ignore_detail_dict(self):
        data = NodeStartStreamResponse.Data(
            id="exec-1",
            node_id="node-1",
            node_type="answer",
            title="Answer",
            index=1,
            predecessor_node_id=None,
            inputs={"foo": "bar"},
            created_at=1,
        )
        response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)

        payload = response.to_ignore_detail_dict()

        assert payload["event"] == StreamEvent.NODE_STARTED.value
        assert payload["data"]["inputs"] is None
        assert payload["data"]["extras"] == {}

    def test_node_finish_to_ignore_detail_dict(self):
        data = NodeFinishStreamResponse.Data(
            id="exec-1",
            node_id="node-1",
            node_type="answer",
            title="Answer",
            index=1,
            predecessor_node_id=None,
            inputs={"foo": "bar"},
            process_data={"step": 1},
            outputs={"answer": "ok"},
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            elapsed_time=0.1,
            created_at=1,
            finished_at=2,
        )
        response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)

        payload = response.to_ignore_detail_dict()

        assert payload["event"] == StreamEvent.NODE_FINISHED.value
        assert payload["data"]["inputs"] is None
        assert payload["data"]["outputs"] is None
        assert payload["data"]["files"] == []

    def test_node_retry_to_ignore_detail_dict(self):
        data = NodeRetryStreamResponse.Data(
            id="exec-1",
            node_id="node-1",
            node_type="answer",
            title="Answer",
            index=1,
            predecessor_node_id=None,
            inputs={"foo": "bar"},
            process_data={"step": 1},
            outputs={"answer": "ok"},
            status=WorkflowNodeExecutionStatus.RETRY,
            elapsed_time=0.1,
            created_at=1,
            finished_at=2,
            retry_index=2,
        )
        response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)

        payload = response.to_ignore_detail_dict()

        assert payload["event"] == StreamEvent.NODE_RETRY.value
        assert payload["data"]["retry_index"] == 2
        assert payload["data"]["outputs"] is None
163
api/tests/unit_tests/core/app/features/test_annotation_reply.py
Normal file
@ -0,0 +1,163 @@
import logging
from types import SimpleNamespace
from unittest.mock import Mock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature


class TestAnnotationReplyFeature:
    def test_query_returns_none_when_setting_missing(self):
        feature = AnnotationReplyFeature()

        with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db:
            mock_db.session.scalar.return_value = None

            result = feature.query(
                app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
                message=SimpleNamespace(id="msg-1"),
                query="hi",
                user_id="user-1",
                invoke_from=InvokeFrom.SERVICE_API,
            )

            assert result is None

    def test_query_returns_none_when_binding_missing(self):
        feature = AnnotationReplyFeature()
        annotation_setting = SimpleNamespace(collection_binding_detail=None)

        with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db:
            mock_db.session.scalar.return_value = annotation_setting

            result = feature.query(
                app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
                message=SimpleNamespace(id="msg-1"),
                query="hi",
                user_id="user-1",
                invoke_from=InvokeFrom.SERVICE_API,
            )

            assert result is None

    def test_query_returns_annotation_and_records_history_for_api(self):
        feature = AnnotationReplyFeature()
        annotation_setting = SimpleNamespace(
            score_threshold=None,
            collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
        )
        dataset_binding = SimpleNamespace(id="binding-1")
        annotation = SimpleNamespace(
            id="ann-1",
            question_text="question",
            content="content",
            account_id="acct-1",
            account=SimpleNamespace(name="Alice"),
        )
        document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8})
        vector_instance = Mock()
        vector_instance.search_by_vector.return_value = [document]

        with (
            patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
            patch(
                "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
            ) as mock_binding_service,
            patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
            patch(
                "core.app.features.annotation_reply.annotation_reply.AppAnnotationService"
            ) as mock_annotation_service,
        ):
            mock_db.session.scalar.return_value = annotation_setting
            mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding
            mock_vector.return_value = vector_instance
            mock_annotation_service.get_annotation_by_id.return_value = annotation

            result = feature.query(
                app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
                message=SimpleNamespace(id="msg-1"),
                query="hi",
                user_id="user-1",
                invoke_from=InvokeFrom.SERVICE_API,
            )

            assert result == annotation
            mock_annotation_service.add_annotation_history.assert_called_once()
            _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0]
            assert from_source == "api"
            assert score == 0.8

    def test_query_returns_annotation_and_records_history_for_console(self):
        feature = AnnotationReplyFeature()
        annotation_setting = SimpleNamespace(
            score_threshold=0.5,
            collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
        )
        dataset_binding = SimpleNamespace(id="binding-1")
        annotation = SimpleNamespace(
            id="ann-1",
            question_text="question",
            content="content",
            account_id="acct-1",
            account=None,
        )
        document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6})
        vector_instance = Mock()
        vector_instance.search_by_vector.return_value = [document]

        with (
            patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
            patch(
                "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
            ) as mock_binding_service,
            patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
            patch(
                "core.app.features.annotation_reply.annotation_reply.AppAnnotationService"
            ) as mock_annotation_service,
        ):
            mock_db.session.scalar.return_value = annotation_setting
            mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding
            mock_vector.return_value = vector_instance
            mock_annotation_service.get_annotation_by_id.return_value = annotation

            result = feature.query(
                app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
                message=SimpleNamespace(id="msg-1"),
                query="hi",
                user_id="user-1",
                invoke_from=InvokeFrom.EXPLORE,
            )

            assert result == annotation
            _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0]
            assert from_source == "console"

    def test_query_logs_and_returns_none_on_exception(self, caplog):
        feature = AnnotationReplyFeature()
        annotation_setting = SimpleNamespace(
            score_threshold=None,
            collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
        )

        with (
            patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
            patch(
                "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
            ) as mock_binding_service,
            patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
        ):
            mock_db.session.scalar.return_value = annotation_setting
            mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1")
            mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom")

            with caplog.at_level(logging.WARNING):
                result = feature.query(
                    app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
                    message=SimpleNamespace(id="msg-1"),
                    query="hi",
                    user_id="user-1",
                    invoke_from=InvokeFrom.SERVICE_API,
                )

            assert result is None
            assert "Query annotation failed" in caplog.text
@ -0,0 +1,30 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch

from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature


class TestHostingModerationFeature:
    def test_check_aggregates_text_and_calls_moderation(self):
        application_generate_entity = Mock()
        application_generate_entity.model_conf = {"model": "mock"}
        application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1")

        prompt_messages = [
            SimpleNamespace(content="hello"),
            SimpleNamespace(content=123),
            SimpleNamespace(content="world"),
        ]

        with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check:
            mock_check.return_value = True

            feature = HostingModerationFeature()
            result = feature.check(application_generate_entity, prompt_messages)

            assert result is True
            mock_check.assert_called_once_with(
                tenant_id="tenant-1",
                model_config=application_generate_entity.model_conf,
                text="hello\nworld\n",
            )
19
api/tests/unit_tests/core/app/layers/test_suspend_layer.py
Normal file
@ -0,0 +1,19 @@
from core.app.layers.suspend_layer import SuspendLayer
from dify_graph.graph_events.graph import GraphRunPausedEvent


class TestSuspendLayer:
    def test_on_event_accepts_paused_event(self):
        layer = SuspendLayer()
        assert layer.is_paused() is False
        layer.on_graph_start()
        assert layer.is_paused() is False
        layer.on_event(GraphRunPausedEvent())
        assert layer.is_paused() is True

    def test_on_event_ignores_other_events(self):
        layer = SuspendLayer()
        layer.on_graph_start()
        initial_state = layer.is_paused()
        layer.on_event(object())
        assert layer.is_paused() is initial_state
98
api/tests/unit_tests/core/app/layers/test_timeslice_layer.py
Normal file
@ -0,0 +1,98 @@
from unittest.mock import Mock, patch

from core.app.layers.timeslice_layer import TimeSliceLayer
from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import SchedulerCommand


class TestTimeSliceLayer:
    def test_init_starts_scheduler_when_not_running(self):
        scheduler = Mock()
        scheduler.running = False

        with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
            _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock()))

        scheduler.start.assert_called_once()

    def test_on_graph_start_adds_job_for_time_slice(self):
        scheduler = Mock()
        scheduler.running = True
        plan = WorkflowScheduleCFSPlanEntity(
            schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
            granularity=3,
        )
        cfs_plan_scheduler = Mock(plan=plan)

        with (
            patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler),
            patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid,
        ):
            mock_uuid.return_value.hex = "job-1"
            layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
            layer.on_graph_start()

        assert layer.schedule_id == "job-1"
        scheduler.add_job.assert_called_once()

    def test_on_graph_end_removes_job(self):
        scheduler = Mock()
        scheduler.running = True
        plan = WorkflowScheduleCFSPlanEntity(
            schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
            granularity=3,
        )
        cfs_plan_scheduler = Mock(plan=plan)

        with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
            layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
            layer.schedule_id = "job-1"
            layer.on_graph_end(None)

        scheduler.remove_job.assert_called_once_with("job-1")

    def test_checker_job_removes_when_stopped(self):
        scheduler = Mock()
        scheduler.running = True
        cfs_plan_scheduler = Mock(plan=Mock())

        with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
            layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
            layer.stopped = True
            layer._checker_job("job-1")

        scheduler.remove_job.assert_called_once_with("job-1")

    def test_checker_job_handles_resource_limit_without_command_channel(self):
        scheduler = Mock()
        scheduler.running = True
        cfs_plan_scheduler = Mock(plan=Mock())
        cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED

        with (
            patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler),
            patch("core.app.layers.timeslice_layer.logger") as mock_logger,
        ):
            layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
            layer._checker_job("job-1")

        scheduler.remove_job.assert_called_once_with("job-1")
        mock_logger.exception.assert_called_once()

    def test_checker_job_sends_pause_command(self):
        scheduler = Mock()
        scheduler.running = True
        cfs_plan_scheduler = Mock(plan=Mock())
        cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED

        with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
            layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
            layer.command_channel = Mock()
            layer._checker_job("job-1")

        scheduler.remove_job.assert_called_once_with("job-1")
        layer.command_channel.send_command.assert_called_once()
        sent_command = layer.command_channel.send_command.call_args[0][0]
        assert isinstance(sent_command, GraphEngineCommand)
        assert sent_command.command_type == CommandType.PAUSE
106
api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
Normal file
@ -0,0 +1,106 @@
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import Mock, patch

from core.app.layers.trigger_post_layer import TriggerPostLayer
from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent
from models.enums import WorkflowTriggerStatus


class TestTriggerPostLayer:
    def test_on_event_updates_trigger_log(self):
        trigger_log = SimpleNamespace(
            status=None,
            workflow_run_id=None,
            outputs=None,
            elapsed_time=None,
            total_tokens=None,
            finished_at=None,
        )
        runtime_state = SimpleNamespace(
            outputs={"answer": "ok"},
            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
            total_tokens=12,
        )

        with (
            patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory,
            patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls,
            patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime,
        ):
            mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC)

            session = Mock()
            mock_session_factory.create_session.return_value.__enter__.return_value = session

            repo = Mock()
            repo.get_by_id.return_value = trigger_log
            mock_repo_cls.return_value = repo

            layer = TriggerPostLayer(
                cfs_plan_scheduler_entity=Mock(),
                start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10),
                trigger_log_id="log-1",
            )
            layer.initialize(runtime_state, Mock())

            layer.on_event(GraphRunSucceededEvent())

        assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED
        assert trigger_log.workflow_run_id == "run-1"
        assert trigger_log.outputs is not None
        assert trigger_log.elapsed_time is not None
        assert trigger_log.total_tokens == 12
        assert trigger_log.finished_at is not None
        repo.update.assert_called_once_with(trigger_log)
        session.commit.assert_called_once()

    def test_on_event_handles_missing_trigger_log(self):
        runtime_state = SimpleNamespace(
            outputs={},
            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
            total_tokens=0,
        )

        with (
            patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory,
            patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls,
            patch("core.app.layers.trigger_post_layer.logger") as mock_logger,
        ):
            session = Mock()
            mock_session_factory.create_session.return_value.__enter__.return_value = session

            repo = Mock()
            repo.get_by_id.return_value = None
            mock_repo_cls.return_value = repo

            layer = TriggerPostLayer(
                cfs_plan_scheduler_entity=Mock(),
                start_time=datetime(2026, 2, 20, tzinfo=UTC),
                trigger_log_id="missing",
            )
            layer.initialize(runtime_state, Mock())

            layer.on_event(GraphRunFailedEvent(error="boom"))

        mock_logger.exception.assert_called_once()
        session.commit.assert_not_called()

    def test_on_event_ignores_non_status_events(self):
        runtime_state = SimpleNamespace(
            outputs={},
            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
            total_tokens=0,
        )

        with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory:
            layer = TriggerPostLayer(
                cfs_plan_scheduler_entity=Mock(),
                start_time=datetime(2026, 2, 20, tzinfo=UTC),
                trigger_log_id="log-1",
            )
            layer.initialize(runtime_state, Mock())

            layer.on_event(Mock())

        mock_session_factory.create_session.assert_not_called()
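One pattern repeated throughout this file deserves a note: `mock_session_factory.create_session.return_value.__enter__.return_value = session` configures what `with session_factory.create_session() as session:` binds inside the production code. A minimal standalone version of the same context-manager mocking:

from unittest.mock import Mock

factory = Mock()
real_session = Mock()
# create_session() returns a context manager; __enter__ supplies the
# object that the `with ... as s` target receives.
factory.create_session.return_value.__enter__.return_value = real_session

with factory.create_session() as s:
    assert s is real_session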
@ -0,0 +1,91 @@
from types import SimpleNamespace
from unittest.mock import Mock

import pytest

from core.app.entities.queue_entities import QueueErrorEvent
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.errors.error import QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from models.enums import MessageStatus


class TestBasedGenerateTaskPipeline:
    @pytest.fixture
    def pipeline(self):
        app_config = SimpleNamespace(
            tenant_id="tenant-1",
            app_id="app-1",
            sensitive_word_avoidance=None,
        )
        app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config)
        return BasedGenerateTaskPipeline(
            application_generate_entity=app_generate_entity,
            queue_manager=Mock(),
            stream=True,
        )

    def test_error_to_desc_quota_exceeded(self, pipeline):
        message = pipeline._error_to_desc(QuotaExceededError())
        assert "quota" in message.lower()

    def test_handle_error_wraps_invoke_authorization(self, pipeline):
        event = QueueErrorEvent(error=InvokeAuthorizationError())
        err = pipeline.handle_error(event=event)
        assert isinstance(err, InvokeAuthorizationError)
        assert str(err) == "Incorrect API key provided"

    def test_handle_error_preserves_invoke_error(self, pipeline):
        event = QueueErrorEvent(error=InvokeError("bad"))
        err = pipeline.handle_error(event=event)
        assert err is event.error

    def test_handle_error_updates_message_when_found(self, pipeline):
        event = QueueErrorEvent(error=ValueError("oops"))
        message = SimpleNamespace(status=MessageStatus.NORMAL, error=None)
        session = Mock()
        session.scalar.return_value = message

        err = pipeline.handle_error(event=event, session=session, message_id="msg-1")

        assert err is event.error
        assert message.status == MessageStatus.ERROR
        assert message.error == "oops"

    def test_handle_error_returns_err_when_message_missing(self, pipeline):
        event = QueueErrorEvent(error=ValueError("oops"))
        session = Mock()
        session.scalar.return_value = None

        err = pipeline.handle_error(event=event, session=session, message_id="msg-1")

        assert err is event.error

    def test_error_to_stream_response_and_ping(self, pipeline):
        error_response = pipeline.error_to_stream_response(ValueError("boom"))
        ping_response = pipeline.ping_stream_response()

        assert error_response.task_id == "task-1"
        assert ping_response.task_id == "task-1"

    def test_handle_output_moderation_when_flagged(self, pipeline):
        handler = Mock()
        handler.moderation_completion.return_value = ("filtered", True)
        pipeline.output_moderation_handler = handler

        result = pipeline.handle_output_moderation_when_task_finished("raw")

        assert result == "filtered"
        handler.stop_thread.assert_called_once()
        assert pipeline.output_moderation_handler is None

    def test_handle_output_moderation_when_not_flagged(self, pipeline):
        handler = Mock()
        handler.moderation_completion.return_value = ("safe", False)
        pipeline.output_moderation_handler = handler

        result = pipeline.handle_output_moderation_when_task_finished("raw")

        assert result is None
        handler.stop_thread.assert_called_once()
        assert pipeline.output_moderation_handler is None
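The two `handle_error` tests above encode an error-normalization rule: authorization failures are replaced with a fixed, provider-agnostic message so no upstream detail leaks, while other invoke errors pass through by identity. A sketch of that rule, with local stand-in classes rather than the real pipeline code:

class InvokeError(Exception):
    pass


class InvokeAuthorizationError(InvokeError):
    pass


def normalize(err: Exception) -> Exception:
    # Authorization errors are rewritten with a constant message.
    if isinstance(err, InvokeAuthorizationError):
        return InvokeAuthorizationError("Incorrect API key provided")
    # Everything else passes through unchanged (same object).
    return err


assert str(normalize(InvokeAuthorizationError("leaked key detail"))) == "Incorrect API key provided"
bad = InvokeError("bad")
assert normalize(bad) is bad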
File diff suppressed because it is too large
11 api/tests/unit_tests/core/app/task_pipeline/test_exc.py Normal file
@ -0,0 +1,11 @@
from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError


class TestTaskPipelineExceptions:
    def test_record_not_found_error_message(self):
        err = RecordNotFoundError("Message", "msg-1")
        assert str(err) == "Message with id msg-1 not found"

    def test_workflow_run_not_found_error_message(self):
        err = WorkflowRunNotFoundError("run-1")
        assert str(err) == "WorkflowRun with id run-1 not found"
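Both assertions fix the exact message format, so one implementation consistent with them (the real classes live in core.app.task_pipeline.exc; this is only a reconstruction from the asserted strings) would be:

class RecordNotFoundError(Exception):
    def __init__(self, record_type: str, record_id: str):
        # Message shape pinned by the first test above.
        super().__init__(f"{record_type} with id {record_id} not found")


class WorkflowRunNotFoundError(RecordNotFoundError):
    def __init__(self, workflow_run_id: str):
        super().__init__("WorkflowRun", workflow_run_id)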
@ -1,12 +1,16 @@
"""Unit tests for the message cycle manager optimization."""

from types import SimpleNamespace
from unittest.mock import Mock, patch

import pytest
from flask import current_app
from flask import Flask, current_app

from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from models.model import AppMode


class TestMessageCycleManagerOptimization:
@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization:
        assert result == StreamEvent.MESSAGE
        mock_session.scalar.assert_called_once()

    def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager):
        """Return MESSAGE_FILE directly from the in-memory cache without opening a DB session."""
        message_cycle_manager._message_has_file.add("cached-message")

        with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
            result = message_cycle_manager.get_message_event_type("cached-message")

        assert result == StreamEvent.MESSAGE_FILE
        mock_session_factory.create_session.assert_not_called()

    def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
        """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
        with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
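The cache test above is the heart of the optimization: a hit in the in-memory `_message_has_file` set must short-circuit before any session is created. Reduced to its core (a sketch only; the real method also queries MessageFile on a miss):

class Manager:
    def __init__(self):
        self._message_has_file: set[str] = set()

    def get_message_event_type(self, message_id: str) -> str:
        if message_id in self._message_has_file:
            return "message_file"  # cache hit: no DB round-trip
        return "message"  # simplified; the real code consults the DB here


m = Manager()
m._message_has_file.add("cached-message")
assert m.get_message_event_type("cached-message") == "message_file"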
@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization:
        assert chunk2_response.event == StreamEvent.MESSAGE
        assert chunk1_response.answer == "Chunk 1"
        assert chunk2_response.answer == "Chunk 2"

    def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager):
        """Return None when completion entities are used for conversation naming.

        Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity.
        Returns: None, indicating no name generation for completion apps.
        Side effects: None expected.
        """

        class DummyCompletion:
            pass

        with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion):
            message_cycle_manager._application_generate_entity = DummyCompletion()
            result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi")

        assert result is None

    def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager):
        """Spawn a background generation thread for the first chat message."""
        message_cycle_manager._application_generate_entity.is_new_conversation = True
        message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True}
        flask_app = object()

        class DummyTimer:
            def __init__(self, interval, function, args=None, kwargs=None):
                self.interval = interval
                self.function = function
                self.args = args or []
                self.kwargs = kwargs
                self.daemon = False
                self.started = False

            def start(self):
                self.started = True

        with (
            patch(
                "core.app.task_pipeline.message_cycle_manager.current_app",
                new=SimpleNamespace(_get_current_object=lambda: flask_app),
            ),
            patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer),
        ):
            thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello")

        assert isinstance(thread, DummyTimer)
        assert thread.interval == 1
        assert thread.function == message_cycle_manager._generate_conversation_name_worker
        assert thread.started is True
        assert thread.daemon is True
        assert thread.kwargs["flask_app"] is flask_app
        assert thread.kwargs["conversation_id"] == "conv-1"
        assert thread.kwargs["query"] == "hello"
        assert message_cycle_manager._application_generate_entity.is_new_conversation is False

    def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager):
        """Skip thread creation when auto naming is disabled but still mark the conversation as not new."""
        message_cycle_manager._application_generate_entity.is_new_conversation = True
        message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False}

        with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer:
            result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello")

        assert result is None
        assert message_cycle_manager._application_generate_entity.is_new_conversation is False
        mock_timer.assert_not_called()

    def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager):
        """Return early when the conversation cannot be found."""
        flask_app = Flask(__name__)
        db_session = Mock()
        db_session.scalar.return_value = None

        with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db:
            mock_db.session = db_session
            message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello")

        db_session.commit.assert_not_called()
        db_session.close.assert_not_called()

    def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager):
        """Return early when a non-completion conversation has no app relation."""
        flask_app = Flask(__name__)
        conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id")
        db_session = Mock()
        db_session.scalar.return_value = conversation

        with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db:
            mock_db.session = db_session
            message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")

        db_session.commit.assert_not_called()
        db_session.close.assert_not_called()

    def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager):
        """Use the cached conversation name when present and avoid the LLM call."""
        flask_app = Flask(__name__)
        conversation = SimpleNamespace(
            mode=AppMode.CHAT,
            app=SimpleNamespace(tenant_id="tenant-1"),
            app_id="app-id",
            name="",
        )
        db_session = Mock()
        db_session.scalar.return_value = conversation

        with (
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
            patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
            patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
        ):
            mock_db.session = db_session
            mock_redis.get.return_value = b"cached-title"

            message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")

        assert conversation.name == "cached-title"
        db_session.commit.assert_called_once()
        db_session.close.assert_called_once()
        mock_llm_generator.generate_conversation_name.assert_not_called()
        mock_redis.setex.assert_not_called()

    def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager):
        """Generate a conversation name and write it to the redis cache on a cache miss."""
        flask_app = Flask(__name__)
        conversation = SimpleNamespace(
            mode=AppMode.CHAT,
            app=SimpleNamespace(tenant_id="tenant-1"),
            app_id="app-id",
            name="",
        )
        db_session = Mock()
        db_session.scalar.return_value = conversation

        with (
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
            patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
            patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
        ):
            mock_db.session = db_session
            mock_redis.get.return_value = None
            mock_llm_generator.generate_conversation_name.return_value = "generated-title"

            message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")

        assert conversation.name == "generated-title"
        db_session.commit.assert_called_once()
        db_session.close.assert_called_once()
        mock_redis.setex.assert_called_once()

    def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager):
        """Fall back to the truncated query when LLM generation fails."""
        flask_app = Flask(__name__)
        conversation = SimpleNamespace(
            mode=AppMode.CHAT,
            app=SimpleNamespace(tenant_id="tenant-1"),
            app_id="app-id",
            name="",
        )
        db_session = Mock()
        db_session.scalar.return_value = conversation
        long_query = "q" * 60

        with (
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
            patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
            patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
            patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config,
            patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger,
        ):
            mock_db.session = db_session
            mock_redis.get.return_value = None
            mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed")
            mock_dify_config.DEBUG = True

            message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query)

        assert conversation.name == (long_query[:47] + "...")
        db_session.commit.assert_called_once()
        db_session.close.assert_called_once()
        mock_logger.exception.assert_called_once()
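The DummyTimer above is a general trick: substitute threading.Timer with a recorder so a test can inspect the interval, callback, and kwargs without spawning a real thread. In isolation (self-contained, standard library only):

import threading
from unittest.mock import patch


class RecordingTimer:
    def __init__(self, interval, function, args=None, kwargs=None):
        self.interval, self.function, self.kwargs = interval, function, kwargs
        self.started = False

    def start(self):
        self.started = True  # record the call instead of starting a thread


def schedule(callback):
    timer = threading.Timer(1, callback, kwargs={"x": 1})
    timer.start()
    return timer


with patch("threading.Timer", RecordingTimer):
    timer = schedule(print)

assert timer.started and timer.interval == 1 and timer.kwargs == {"x": 1}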
    def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager):
        """Populate task metadata from annotation reply events.

        Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService.
        Returns: The fetched annotation object.
        Side effects: Updates metadata.annotation_reply with id and account name.
        """
        message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())

        annotation = SimpleNamespace(
            id="ann-1",
            account_id="acct-1",
            account=SimpleNamespace(name="Alice"),
        )

        with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service:
            mock_service.get_annotation_by_id.return_value = annotation

            result = message_cycle_manager.handle_annotation_reply(
                QueueAnnotationReplyEvent(message_annotation_id="ann-1")
            )

        assert result == annotation
        assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1"
        assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice"

    def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager):
        """Return None and keep metadata unchanged when the annotation is not found."""
        message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())

        with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service:
            mock_service.get_annotation_by_id.return_value = None

            result = message_cycle_manager.handle_annotation_reply(
                QueueAnnotationReplyEvent(message_annotation_id="missing")
            )

        assert result is None
        assert message_cycle_manager._task_state.metadata.annotation_reply is None

    def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager):
        """Merge retriever resources, deduplicate, and preserve ordering positions.

        Args: message_cycle_manager with show_retrieve_source enabled and existing metadata.
        Returns: None.
        Side effects: Updates metadata.retriever_resources with unique items and positions.
        """
        message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(
            additional_features=SimpleNamespace(show_retrieve_source=True)
        )
        existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")
        message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing]))

        duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")
        new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2")

        event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource])
        message_cycle_manager.handle_retriever_resources(event)

        assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2
        assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1
        assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2

    def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager):
        """Build a stream response with a signed tool file URL.

        Args: message_cycle_manager with mocked Session/db and sign_tool_file.
        Returns: MessageStreamResponse with signed url and belongs_to normalized to user.
        Side effects: Calls sign_tool_file for tool file ids.
        """
        message_cycle_manager._application_generate_entity.task_id = "task-1"

        message_file = SimpleNamespace(
            id="file-1",
            type="image",
            belongs_to=None,
            url="tool://file.verylongextension",
            message_id="msg-1",
        )

        session = Mock()
        session.scalar.return_value = message_file

        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
            patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign,
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
        ):
            mock_db.engine = Mock()
            mock_session_cls.return_value.__enter__.return_value = session
            mock_sign.return_value = "signed-url"

            response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1"))

        assert response.url == "signed-url"
        assert response.belongs_to == "user"
        mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin")

    def test_handle_retriever_resources_requires_features(self, message_cycle_manager):
        """Raise when retriever resources are handled without feature config.

        Args: message_cycle_manager with additional_features unset and empty metadata.
        Raises: ValueError when show_retrieve_source configuration is missing.
        """
        message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None)
        message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())

        with pytest.raises(ValueError):
            message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[]))

    def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager):
        """Ignore null resource entries while preserving valid resources."""
        message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(
            additional_features=SimpleNamespace(show_retrieve_source=True)
        )
        message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[]))
        resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")

        message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource]))

        assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1
        assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1

    def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager):
        """Use the original URL when the message file URL is already HTTP."""
        message_cycle_manager._application_generate_entity.task_id = "task-http"
        message_file = SimpleNamespace(
            id="file-http",
            type="image",
            belongs_to="assistant",
            url="http://example.com/pic.png",
            message_id="msg-http",
        )

        session = Mock()
        session.scalar.return_value = message_file

        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
        ):
            mock_db.engine = Mock()
            mock_session_cls.return_value.__enter__.return_value = session

            response = message_cycle_manager.message_file_to_stream_response(
                SimpleNamespace(message_file_id="file-http")
            )

        assert response is not None
        assert response.url == "http://example.com/pic.png"
        assert "msg-http" in message_cycle_manager._message_has_file

    def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager):
        """Default the tool file extension to .bin when the URL has no extension part."""
        message_cycle_manager._application_generate_entity.task_id = "task-bin"
        message_file = SimpleNamespace(
            id="file-bin",
            type="file",
            belongs_to="assistant",
            url="tool-file-id",
            message_id="msg-bin",
        )

        session = Mock()
        session.scalar.return_value = message_file

        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
            patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign,
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
        ):
            mock_db.engine = Mock()
            mock_session_cls.return_value.__enter__.return_value = session
            mock_sign.return_value = "signed-bin-url"

            response = message_cycle_manager.message_file_to_stream_response(
                SimpleNamespace(message_file_id="file-bin")
            )

        assert response is not None
        assert response.url == "signed-bin-url"
        mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin")

    def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager):
        """Return None when the message file lookup does not find a record."""
        session = Mock()
        session.scalar.return_value = None

        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
            patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
        ):
            mock_db.engine = Mock()
            mock_session_cls.return_value.__enter__.return_value = session

            response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing"))

        assert response is None

    def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager):
        """Include the provided replacement reason in the stream payload."""
        response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation")

        assert response.answer == "replaced"
        assert response.reason == "moderation"
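Taken together, the retriever-resource tests specify a merge that drops None entries, deduplicates, and renumbers survivors 1..n through a `position` field. An illustrative dict-based reduction (the real code works on RetrievalSourceMetadata, and the dataset_id/document_id identity rule is an assumption inferred from the duplicate fixture):

def merge_resources(existing: list[dict], incoming: list[dict | None]) -> list[dict]:
    seen: set[tuple[str, str]] = set()
    merged: list[dict] = []
    for res in [*existing, *(r for r in incoming if r is not None)]:
        key = (res["dataset_id"], res["document_id"])  # assumed identity rule
        if key in seen:
            continue
        seen.add(key)
        merged.append(res)
    for position, res in enumerate(merged, start=1):
        res["position"] = position  # renumber survivors 1..n
    return merged


merged = merge_resources(
    [{"dataset_id": "d1", "document_id": "doc1"}],
    [None, {"dataset_id": "d1", "document_id": "doc1"}, {"dataset_id": "d2", "document_id": "doc2"}],
)
assert [r["position"] for r in merged] == [1, 2]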
43 api/tests/unit_tests/core/app/workflow/test_file_runtime.py Normal file
@ -0,0 +1,43 @@
from unittest.mock import patch

from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime


class TestDifyWorkflowFileRuntime:
    def test_runtime_properties_and_helpers(self, monkeypatch):
        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files")
        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal")
        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret")
        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123)
        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url")

        runtime = DifyWorkflowFileRuntime()

        assert runtime.files_url == "http://files"
        assert runtime.internal_files_url == "http://internal"
        assert runtime.secret_key == "secret"
        assert runtime.files_access_timeout == 123
        assert runtime.multimodal_send_format == "url"

        with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get:
            mock_get.return_value = "response"
            assert runtime.http_get("http://example", follow_redirects=False) == "response"
            mock_get.assert_called_once_with("http://example", follow_redirects=False)

        with patch("core.app.workflow.file_runtime.storage.load") as mock_load:
            mock_load.return_value = b"data"
            assert runtime.storage_load("path", stream=True) == b"data"
            mock_load.assert_called_once_with("path", stream=True)

        with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign:
            mock_sign.return_value = "signed"
            assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed"
            mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False)

    def test_bind_runtime_registers_instance(self):
        with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set:
            bind_dify_workflow_file_runtime()

        mock_set.assert_called_once()
        runtime = mock_set.call_args[0][0]
        assert isinstance(runtime, DifyWorkflowFileRuntime)
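DifyWorkflowFileRuntime is tested here as a thin delegator: each helper must forward its arguments verbatim and return whatever the backend returns. The pattern in a self-contained form (the `Runtime` class below is a stand-in, not Dify's):

from unittest.mock import Mock


class Runtime:
    def __init__(self, transport):
        self._transport = transport

    def http_get(self, url, **kwargs):
        return self._transport.get(url, **kwargs)  # pure pass-through


transport = Mock()
transport.get.return_value = "response"
runtime = Runtime(transport)

# Assert both the return value and that the arguments passed through unchanged.
assert runtime.http_get("http://example", follow_redirects=False) == "response"
transport.get.assert_called_once_with("http://example", follow_redirects=False)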
161 api/tests/unit_tests/core/app/workflow/test_node_factory.py Normal file
@ -0,0 +1,161 @@
from types import SimpleNamespace

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.enums import BuiltinNodeTypes


class DummyNode:
    def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
        self.id = id
        self.config = config
        self.graph_init_params = graph_init_params
        self.graph_runtime_state = graph_runtime_state
        self.kwargs = kwargs


class DummyCodeNode(DummyNode):
    @classmethod
    def default_code_providers(cls):
        return ()


class DummyTemplateTransformNode(DummyNode):
    pass


class DummyHttpRequestNode(DummyNode):
    pass


class DummyKnowledgeRetrievalNode(DummyNode):
    pass


class DummyDocumentExtractorNode(DummyNode):
    pass


class TestDifyNodeFactory:
    @staticmethod
    def _stub_node_resolution(monkeypatch, node_class):
        monkeypatch.setattr(
            "core.workflow.node_factory.resolve_workflow_node_class",
            lambda **_kwargs: node_class,
        )

    def _factory(self, monkeypatch):
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100)
        monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u")
        monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key")

        run_context = build_dify_run_context(
            tenant_id="tenant",
            app_id="app",
            user_id="user",
            user_from=UserFrom.END_USER,
            invoke_from=InvokeFrom.WEB_APP,
        )

        return DifyNodeFactory(
            graph_init_params=SimpleNamespace(run_context=run_context),
            graph_runtime_state=SimpleNamespace(),
        )

    def test_create_node_unknown_type(self, monkeypatch):
        factory = self._factory(monkeypatch)

        with pytest.raises(ValueError):
            factory.create_node({"id": "node-1", "data": {"type": "unknown"}})

    def test_create_node_missing_mapping(self, monkeypatch):
        factory = self._factory(monkeypatch)
        monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {})

        with pytest.raises(ValueError):
            factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}})

    def test_create_node_missing_latest_class(self, monkeypatch):
        factory = self._factory(monkeypatch)
        monkeypatch.setattr(
            "core.workflow.node_factory.get_node_type_classes_mapping",
            lambda: {BuiltinNodeTypes.START: {"1": None}},
        )
        monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest")

        with pytest.raises(ValueError):
            factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}})

    def test_create_node_selects_versioned_class(self, monkeypatch):
        factory = self._factory(monkeypatch)
        selected_versions: list[tuple[str, str]] = []

        class DummyNodeV2(DummyNode):
            pass

        def _get_mapping():
            selected_versions.append(("snapshot", "called"))
            return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}}

        monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}})

        assert isinstance(node, DummyNodeV2)
        assert node.id == "node-1"
        assert selected_versions == [("snapshot", "called")]

    def test_create_node_code_branch(self, monkeypatch):
        factory = self._factory(monkeypatch)
        self._stub_node_resolution(monkeypatch, DummyCodeNode)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}})

        assert isinstance(node, DummyCodeNode)
        assert node.id == "node-1"

    def test_create_node_template_transform_branch(self, monkeypatch):
        factory = self._factory(monkeypatch)
        self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}})

        assert isinstance(node, DummyTemplateTransformNode)
        assert "template_renderer" in node.kwargs

    def test_create_node_http_request_branch(self, monkeypatch):
        factory = self._factory(monkeypatch)
        self._stub_node_resolution(monkeypatch, DummyHttpRequestNode)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}})

        assert isinstance(node, DummyHttpRequestNode)
        assert "http_request_config" in node.kwargs

    def test_create_node_knowledge_retrieval_branch(self, monkeypatch):
        factory = self._factory(monkeypatch)
        self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}})

        assert isinstance(node, DummyKnowledgeRetrievalNode)
        assert node.kwargs == {}

    def test_create_node_document_extractor_branch(self, monkeypatch):
        factory = self._factory(monkeypatch)
        self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode)

        node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}})

        assert isinstance(node, DummyDocumentExtractorNode)
        assert "unstructured_api_config" in node.kwargs
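The three failure tests plus the versioned-selection test imply a resolution rule along these lines: look up the type's version map, fall back to LATEST_VERSION when no version is given, and raise ValueError when either lookup misses. A sketch under those assumptions (not the actual factory code; the LATEST_VERSION fallback is inferred from the patched value in test_create_node_missing_latest_class):

LATEST_VERSION = "latest"  # assumed default key


def resolve(mapping: dict, node_type: str, version: str | None = None):
    versions = mapping.get(node_type)
    if not versions:
        raise ValueError(f"no classes registered for node type {node_type!r}")
    cls = versions.get(version or LATEST_VERSION)
    if cls is None:
        raise ValueError(f"no class for {node_type!r} version {version or LATEST_VERSION!r}")
    return cls


class StartV1: ...
class StartV2: ...


assert resolve({"start": {"1": StartV1, "2": StartV2}}, "start", "2") is StartV2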
@ -0,0 +1,209 @@
from __future__ import annotations

from types import SimpleNamespace

from core.app.workflow.layers.observability import ObservabilityLayer
from dify_graph.enums import BuiltinNodeTypes


class TestObservabilityLayerExtras:
    def test_init_tracer_enabled_sets_tracer(self, monkeypatch):
        tracer = object()
        monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
        monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)
        monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer)

        layer = ObservabilityLayer()

        assert layer._is_disabled is False
        assert layer._tracer is tracer

    def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog):
        monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
        monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)

        def _raise(*_args, **_kwargs):
            raise RuntimeError("tracer init failed")

        monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise)

        layer = ObservabilityLayer()

        assert layer._is_disabled is True
        assert layer._tracer is None
        assert "Failed to get OpenTelemetry tracer" in caplog.text

    def test_init_tracer_disables_when_otel_disabled(self, monkeypatch):
        monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
        monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)

        layer = ObservabilityLayer()

        assert layer._is_disabled is True

    def test_get_parser_uses_registry_when_node_type_matches(self):
        layer = ObservabilityLayer()

        parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL))

        assert parser is layer._parsers[BuiltinNodeTypes.TOOL]

    def test_get_parser_defaults_when_node_type_missing(self):
        layer = ObservabilityLayer()

        parser = layer._get_parser(SimpleNamespace(node_type=None))

        assert parser is layer._default_parser

    def test_on_graph_start_clears_contexts(self):
        layer = ObservabilityLayer()
        layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")

        layer.on_graph_start()

        assert layer._node_contexts == {}

    def test_on_event_is_noop(self):
        layer = ObservabilityLayer()

        layer.on_event(object())

    def test_on_graph_end_clears_unfinished_contexts(self, caplog):
        layer = ObservabilityLayer()
        layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")

        layer.on_graph_end(error=None)

        assert layer._node_contexts == {}
        assert "node spans were not properly ended" in caplog.text

    def test_on_node_run_start_skips_without_execution_id(self):
        layer = ObservabilityLayer()
        layer._is_disabled = False
        layer._tracer = None

        layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node"))

        assert layer._node_contexts == {}

    def test_on_node_run_start_skips_when_disabled(self):
        layer = ObservabilityLayer()
        layer._is_disabled = True
        layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object())

        layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node"))

        assert layer._node_contexts == {}

    def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self):
        layer = ObservabilityLayer()
        layer._is_disabled = False
        calls: list[str] = []
        layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called"))

        layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node"))

        assert calls == []

    def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog):
        layer = ObservabilityLayer()
        layer._is_disabled = False

        def _raise(*_args, **_kwargs):
            raise RuntimeError("start failed")

        layer._tracer = SimpleNamespace(start_span=_raise)

        layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node"))

        assert "Failed to create OpenTelemetry span for node" in caplog.text

    def test_on_node_run_end_without_context_noop(self):
        layer = ObservabilityLayer()
        layer._is_disabled = False

        layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None)

        assert layer._node_contexts == {}

    def test_on_node_run_end_skips_when_disabled(self):
        layer = ObservabilityLayer()
        layer._is_disabled = True
        layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")

        layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None)

        assert "exec" in layer._node_contexts

    def test_on_node_run_end_skips_without_execution_id(self):
        layer = ObservabilityLayer()
        layer._is_disabled = False

        layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None)

        assert layer._node_contexts == {}

    def test_on_node_run_end_calls_span_end(self, monkeypatch):
        layer = ObservabilityLayer()
        layer._is_disabled = False
        ended: list[str] = []

        class _Parser:
            def parse(self, **_kwargs):
                return None

        span = SimpleNamespace(end=lambda: ended.append("ended"))
        layer._default_parser = _Parser()
        layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token")

        monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None)

        node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)
        layer.on_node_run_end(node, error=None)

        assert ended == ["ended"]
        assert "exec" not in layer._node_contexts

    def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog):
        layer = ObservabilityLayer()
        layer._is_disabled = False

        class _Parser:
            def parse(self, **_kwargs):
                return None

        layer._default_parser = _Parser()
        layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token")

        def _raise(*_args, **_kwargs):
            raise RuntimeError("detach failed")

        monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise)

        node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)
        layer.on_node_run_end(node, error=None)

        assert "Failed to detach OpenTelemetry token" in caplog.text
        assert "exec" not in layer._node_contexts

    def test_on_node_run_start_and_end_creates_span(self, monkeypatch):
        layer = ObservabilityLayer()
        layer._is_disabled = False

        span = SimpleNamespace(end=lambda: None)
        tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span)

        monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object())
        monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object())
        monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token")
        monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None)

        layer._tracer = tracer

        node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)

        layer.on_node_run_start(node)
        assert "exec" in layer._node_contexts

        layer.on_node_run_end(node, error=None)
        assert "exec" not in layer._node_contexts
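Several of these tests assert against caplog.text, pytest's built-in log capture fixture. It collects records emitted during the test body and exposes their formatted output, so substring checks against logged messages work directly:

import logging


def test_caplog_captures_warning(caplog):
    with caplog.at_level(logging.WARNING):
        logging.getLogger("demo").warning("node spans were not properly ended")
    # caplog.text holds the formatted output of all captured records.
    assert "node spans were not properly ended" in caplog.text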
499 api/tests/unit_tests/core/app/workflow/test_persistence_layer.py Normal file
@ -0,0 +1,499 @@
from __future__ import annotations

from datetime import UTC, datetime
from types import SimpleNamespace

import pytest

from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from dify_graph.entities.pause_reason import SchedulingPause
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from dify_graph.enums import (
    BuiltinNodeTypes,
    SystemVariableKey,
    WorkflowExecutionStatus,
    WorkflowNodeExecutionStatus,
    WorkflowType,
)
from dify_graph.graph_events.graph import (
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)
from dify_graph.graph_events.node import (
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunPauseRequestedEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,
    NodeRunSucceededEvent,
)
from dify_graph.node_events import NodeRunResult
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
from dify_graph.system_variable import SystemVariable


class _RepoRecorder:
    def __init__(self) -> None:
        self.saved: list[object] = []
        self.saved_exec_data: list[object] = []

    def save(self, entity):
        self.saved.append(entity)

    def save_execution_data(self, entity):
        self.saved_exec_data.append(entity)


def _naive_utc_now() -> datetime:
    return datetime.now(UTC).replace(tzinfo=None)


def _make_layer(
    system_variable: SystemVariable | None = None,
    *,
    extras: dict | None = None,
    trace_manager: object | None = None,
):
    system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id")
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0)
    read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)

    application_generate_entity = WorkflowAppGenerateEntity.model_construct(
        task_id="task",
        app_config=SimpleNamespace(app_id="app", tenant_id="tenant"),
        inputs={"foo": "bar"},
        files=[],
        user_id="user",
        stream=False,
        invoke_from=None,
        trace_manager=None,
        workflow_execution_id="run-id",
        extras=extras or {},
        call_depth=0,
    )

    workflow_info = PersistenceWorkflowInfo(
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        version="1",
        graph_data={"nodes": [], "edges": []},
    )

    workflow_execution_repo = _RepoRecorder()
    workflow_node_execution_repo = _RepoRecorder()

    layer = WorkflowPersistenceLayer(
        application_generate_entity=application_generate_entity,
        workflow_info=workflow_info,
        workflow_execution_repository=workflow_execution_repo,
        workflow_node_execution_repository=workflow_node_execution_repo,
        trace_manager=trace_manager,
    )
    layer.initialize(read_only_state, command_channel=None)

    return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state
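`WorkflowAppGenerateEntity.model_construct(...)` in the helper above is Pydantic v2's validation bypass: it assembles the model without running validators or coercion, which is why `invoke_from=None` and a SimpleNamespace app_config are accepted here. A minimal demonstration with a hypothetical model:

from pydantic import BaseModel


class Entity(BaseModel):
    task_id: str
    call_depth: int


# model_construct skips validation and coercion entirely:
entity = Entity.model_construct(task_id="task", call_depth="not-an-int")
assert entity.call_depth == "not-an-int"  # no type check was applied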
class TestWorkflowPersistenceLayer:
|
||||
def test_on_graph_start_resets_state(self):
|
||||
layer, _, _, _ = _make_layer()
|
||||
layer._workflow_execution = object()
|
||||
layer._node_execution_cache["cached"] = object()
|
||||
layer._node_snapshots["cached"] = object()
|
||||
layer._node_sequence = 9
|
||||
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer._workflow_execution is None
|
||||
assert layer._node_execution_cache == {}
|
||||
assert layer._node_snapshots == {}
|
||||
assert layer._node_sequence == 0
|
||||
|
||||
def test_get_execution_id_requires_system_variable(self):
|
||||
system_variable = SystemVariable(workflow_execution_id=None)
|
||||
layer, _, _, _ = _make_layer(system_variable)
|
||||
|
||||
with pytest.raises(ValueError, match="workflow_execution_id must be provided"):
|
||||
layer._get_execution_id()
|
||||
|
||||
def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch):
|
||||
layer, _, _, _ = _make_layer()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.workflow_entry.WorkflowEntry.handle_special_values",
|
||||
lambda inputs: inputs,
|
||||
)
|
||||
|
||||
inputs = layer._prepare_workflow_inputs()
|
||||
|
||||
assert "sys.conversation_id" not in inputs
|
||||
assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id"
|
||||
|
||||
def test_fail_running_node_executions_marks_failed(self):
|
||||
layer, _, node_repo, _ = _make_layer()
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id="exec-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
index=1,
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
created_at=_naive_utc_now(),
|
||||
)
|
||||
layer._node_execution_cache[execution.id] = execution
|
||||
|
||||
layer._fail_running_node_executions(error_message="boom")
|
||||
|
||||
assert execution.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert node_repo.saved
|
||||
|
||||
def test_handle_graph_run_started_saves_execution(self):
|
||||
layer, exec_repo, _, _ = _make_layer()
|
||||
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
assert exec_repo.saved
|
||||
|
||||
def test_handle_graph_run_succeeded_updates_execution(self):
|
||||
layer, exec_repo, _, runtime_state = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
runtime_state.total_tokens = 3
|
||||
runtime_state.node_run_steps = 2
|
||||
runtime_state.outputs = {"out": "v"}
|
||||
|
||||
layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
|
||||
|
||||
saved = exec_repo.saved[-1]
|
||||
assert saved.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert saved.total_tokens == 3
|
||||
assert saved.total_steps == 2
|
||||
|
||||
def test_handle_graph_run_partial_succeeded_updates_execution(self):
|
||||
layer, exec_repo, _, runtime_state = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
runtime_state.total_tokens = 5
|
||||
runtime_state.node_run_steps = 4
|
||||
runtime_state._graph_execution = SimpleNamespace(exceptions_count=2)
|
||||
|
||||
layer._handle_graph_run_partial_succeeded(
|
||||
GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2)
|
||||
)
|
||||
|
||||
saved = exec_repo.saved[-1]
|
||||
assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
assert saved.exceptions_count == 2
|
||||
assert saved.total_tokens == 5
|
||||
|
||||
def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self):
|
||||
trace_tasks: list[object] = []
|
||||
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
|
||||
layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager)
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
running = WorkflowNodeExecution(
|
||||
id="node-exec",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
index=1,
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
created_at=_naive_utc_now(),
|
||||
)
|
||||
layer._node_execution_cache[running.id] = running
|
||||
|
||||
layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1))
|
||||
|
||||
assert node_repo.saved
|
||||
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
|
||||
assert trace_tasks
|
||||
|
||||
def test_handle_graph_run_aborted_sets_status(self):
|
||||
layer, exec_repo, _, _ = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={}))
|
||||
|
||||
saved = exec_repo.saved[-1]
|
||||
assert saved.status == WorkflowExecutionStatus.STOPPED
|
||||
assert saved.error_message
|
||||
|
||||
def test_handle_graph_run_paused_updates_outputs(self):
|
||||
layer, exec_repo, _, runtime_state = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
runtime_state.total_tokens = 7
|
||||
runtime_state.node_run_steps = 5
|
||||
|
||||
layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True}))
|
||||
|
||||
saved = exec_repo.saved[-1]
|
||||
assert saved.status == WorkflowExecutionStatus.PAUSED
|
||||
assert saved.outputs == {"pause": True}
|
||||
assert saved.finished_at is None
|
||||
|
||||
def test_handle_node_started_and_retry(self):
|
||||
layer, _, node_repo, _ = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
node_title="Start",
|
||||
start_at=_naive_utc_now(),
|
||||
predecessor_node_id="prev",
|
||||
in_iteration_id="iter",
|
||||
in_loop_id="loop",
|
||||
)
|
||||
layer._handle_node_started(start_event)
|
||||
|
||||
assert node_repo.saved
|
||||
assert "exec" in layer._node_execution_cache
|
||||
assert layer._node_snapshots["exec"].node_id == "node"
|
||||
|
||||
retry_event = NodeRunRetryEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
node_title="Start",
|
||||
start_at=_naive_utc_now(),
|
||||
error="retry",
|
||||
retry_index=1,
|
||||
)
|
||||
layer._handle_node_retry(retry_event)
|
||||
assert node_repo.saved_exec_data
|
||||
|
||||
def test_handle_node_result_events_update_execution(self):
|
||||
layer, _, node_repo, _ = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
node_title="LLM",
|
||||
start_at=_naive_utc_now(),
|
||||
)
|
||||
layer._handle_node_started(start_event)
|
||||
|
||||
result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={})
|
||||
success_event = NodeRunSucceededEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
start_at=_naive_utc_now(),
|
||||
node_run_result=result,
|
||||
)
|
||||
layer._handle_node_succeeded(success_event)
|
||||
|
||||
failed_event = NodeRunFailedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
start_at=_naive_utc_now(),
|
||||
error="boom",
|
||||
node_run_result=result,
|
||||
)
|
||||
layer._handle_node_failed(failed_event)
|
||||
|
||||
exception_event = NodeRunExceptionEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
start_at=_naive_utc_now(),
|
||||
error="err",
|
||||
node_run_result=result,
|
||||
)
|
||||
layer._handle_node_exception(exception_event)
|
||||
|
||||
assert node_repo.saved_exec_data
|
||||
|
||||
    def test_handle_node_pause_requested_skips_outputs(self):
        layer, _, _, _ = _make_layer()
        layer._handle_graph_run_started()
        start_event = NodeRunStartedEvent(
            id="exec",
            node_id="node",
            node_type=BuiltinNodeTypes.LLM,
            node_title="LLM",
            start_at=_naive_utc_now(),
        )
        layer._handle_node_started(start_event)

        domain_execution = layer._node_execution_cache["exec"]
        domain_execution.inputs = {"old": True}

        result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={})
        pause_event = NodeRunPauseRequestedEvent(
            id="exec",
            node_id="node",
            node_type=BuiltinNodeTypes.LLM,
            reason=SchedulingPause(message="pause"),
            node_run_result=result,
        )
        layer._handle_node_pause_requested(pause_event)

        assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED
        assert domain_execution.inputs == {"old": True}

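    # Smaller contract checks: missing lookups raise, the node sequence
    # counter increments monotonically, and on_graph_end is a no-op.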
    def test_get_node_execution_raises_for_missing(self):
        layer, _, _, _ = _make_layer()
        with pytest.raises(ValueError, match="Node execution not found"):
            layer._get_node_execution("missing")

    def test_get_workflow_execution_raises_when_uninitialized(self):
        layer, _, _, _ = _make_layer()

        with pytest.raises(ValueError, match="workflow execution not initialized"):
            layer._get_workflow_execution()

    def test_next_node_sequence_increments(self):
        layer, _, _, _ = _make_layer()
        assert layer._next_node_sequence() == 1
        assert layer._next_node_sequence() == 2

    def test_on_graph_end_is_noop(self):
        layer, _, _, _ = _make_layer()

        assert layer.on_graph_end(error=None) is None

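    # One event of every known type, dispatched in order; each must reach
    # exactly its own handler.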
    def test_on_event_dispatches_to_all_known_handlers(self):
        layer, _, _, _ = _make_layer()
        called: list[str] = []

        def _record(name: str):
            def _handler(*_args, **_kwargs):
                called.append(name)

            return _handler

        layer._handle_graph_run_started = _record("started")
        layer._handle_graph_run_succeeded = _record("succeeded")
        layer._handle_graph_run_partial_succeeded = _record("partial")
        layer._handle_graph_run_failed = _record("failed")
        layer._handle_graph_run_aborted = _record("aborted")
        layer._handle_graph_run_paused = _record("paused")
        layer._handle_node_started = _record("node_started")
        layer._handle_node_retry = _record("node_retry")
        layer._handle_node_succeeded = _record("node_succeeded")
        layer._handle_node_failed = _record("node_failed")
        layer._handle_node_exception = _record("node_exception")
        layer._handle_node_pause_requested = _record("node_paused")

        node_result = NodeRunResult()
        now = _naive_utc_now()
        events = [
            GraphRunStartedEvent(),
            GraphRunSucceededEvent(outputs={"ok": True}),
            GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1),
            GraphRunFailedEvent(error="boom", exceptions_count=1),
            GraphRunAbortedEvent(reason="stop", outputs={"x": 1}),
            GraphRunPausedEvent(outputs={"pause": True}),
            NodeRunStartedEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                node_title="Start",
                start_at=now,
            ),
            NodeRunRetryEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                node_title="Start",
                start_at=now,
                error="retry",
                retry_index=1,
            ),
            NodeRunSucceededEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                start_at=now,
                node_run_result=node_result,
            ),
            NodeRunFailedEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                start_at=now,
                error="failed",
                node_run_result=node_result,
            ),
            NodeRunExceptionEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                start_at=now,
                error="error",
                node_run_result=node_result,
            ),
            NodeRunPauseRequestedEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                reason=SchedulingPause(message="pause"),
                node_run_result=node_result,
            ),
        ]
        expected_order = [
            "started",
            "succeeded",
            "partial",
            "failed",
            "aborted",
            "paused",
            "node_started",
            "node_retry",
            "node_succeeded",
            "node_failed",
            "node_exception",
            "node_paused",
        ]

        for event in events:
            layer.on_event(event)

        assert called == expected_order

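    # Ordering pin: a retry event must reach only the retry handler, never the
    # started handler. This presumably guards against NodeRunRetryEvent being
    # matched by an earlier NodeRunStartedEvent isinstance check (the exact
    # class relationship is an assumption, not verified here).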
    def test_on_event_dispatches_retry_before_started_for_retry_event(self):
        layer, _, _, _ = _make_layer()
        called: list[str] = []

        def _record(name: str):
            def _handler(*_args, **_kwargs):
                called.append(name)

            return _handler

        layer._handle_node_started = _record("node_started")
        layer._handle_node_retry = _record("node_retry")

        layer.on_event(
            NodeRunRetryEvent(
                id="exec",
                node_id="node",
                node_type=BuiltinNodeTypes.START,
                node_title="Start",
                start_at=_naive_utc_now(),
                error="retry",
                retry_index=1,
            )
        )

        assert called == ["node_retry"]

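    # A minimal, self-contained sketch of why the ordering above matters. The
    # stub classes here are hypothetical stand-ins, not the production event
    # types: with isinstance dispatch, the base-class branch swallows any
    # subclass unless the more specific check runs first.
    def test_isinstance_dispatch_order_sketch(self):
        class _StartedStub: ...

        class _RetryStub(_StartedStub): ...

        def dispatch(event: object) -> str:
            # Most specific type first; swapping these two checks would route
            # _RetryStub through the "started" branch.
            if isinstance(event, _RetryStub):
                return "retry"
            if isinstance(event, _StartedStub):
                return "started"
            return "unhandled"

        assert dispatch(_RetryStub()) == "retry"
        assert dispatch(_StartedStub()) == "started"
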
    def test_enqueue_trace_task_skips_when_disabled(self):
        trace_tasks: list[object] = []
        layer, exec_repo, _, _ = _make_layer()
        layer._handle_graph_run_started()
        layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
        assert exec_repo.saved
        assert not trace_tasks

@ -110,7 +110,7 @@ def test_encrypt_tool_parameters():
    assert encrypted["plain"] == "x"


def test_decrypt_tool_parameters_cache_hit_and_miss():
def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch):
    manager = _build_manager()

    with (
@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache():
    mock_delete.assert_called_once()


def test_configuration_manager_decrypt_suppresses_errors():
def test_configuration_manager_decrypt_suppresses_errors(monkeypatch):
    manager = _build_manager()
    with (
        patch.object(ToolParameterCache, "get", return_value=None),