test: unit test cases for sub modules in core.app (except core.app.apps) (#32476)

This commit is contained in:
Rajat Agarwal 2026-03-24 23:43:28 +05:30 committed by GitHub
parent e873cea99e
commit 36cc1bf025
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 5772 additions and 13 deletions

View File

@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent
class SuspendLayer(GraphEngineLayer):
""" """
def __init__(self) -> None:
super().__init__()
self._paused = False
def on_graph_start(self):
pass
self._paused = False
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
"""
if isinstance(event, GraphRunPausedEvent):
pass
self._paused = True
def on_graph_end(self, error: Exception | None):
""" """
pass
self._paused = False
def is_paused(self) -> bool:
return self._paused

View File

@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._handle_graph_run_paused(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
return

View File

@ -68,7 +68,7 @@ def init_tool_node(config: dict):
return node
def test_tool_variable_invoke():
def test_tool_variable_invoke(monkeypatch):
node = init_tool_node(
config={
"id": "1",
@ -103,7 +103,7 @@ def test_tool_variable_invoke():
assert item.node_run_result.outputs.get("text") is not None
def test_tool_mixed_invoke():
def test_tool_mixed_invoke(monkeypatch):
node = init_tool_node(
config={
"id": "1",

View File

@ -0,0 +1,227 @@
from unittest.mock import MagicMock
import pytest
# Module under test
from core.app.app_config.common import parameters_mapping
class TestGetParametersFromFeatureDict:
"""Test suite for get_parameters_from_feature_dict"""
@pytest.fixture
def mock_config(self, monkeypatch):
"""Mock dify_config values"""
mock = MagicMock()
mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1
mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2
mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3
mock.UPLOAD_FILE_SIZE_LIMIT = 4
mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5
monkeypatch.setattr(parameters_mapping, "dify_config", mock)
return mock
@pytest.fixture
def mock_default_file_limits(self, monkeypatch):
"""Mock DEFAULT_FILE_NUMBER_LIMITS constant"""
monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99)
return 99
@pytest.fixture
def minimal_inputs(self):
return {}, []
@pytest.mark.parametrize(
("feature_key", "expected_default"),
[
("suggested_questions", []),
("suggested_questions_after_answer", {"enabled": False}),
("speech_to_text", {"enabled": False}),
("text_to_speech", {"enabled": False}),
("retriever_resource", {"enabled": False}),
("annotation_reply", {"enabled": False}),
("more_like_this", {"enabled": False}),
(
"sensitive_word_avoidance",
{"enabled": False, "type": "", "configs": []},
),
],
)
def test_defaults_when_key_missing(
self,
feature_key,
expected_default,
mock_config,
mock_default_file_limits,
):
# Arrange
features = {}
user_input = []
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=user_input,
)
# Assert
assert result[feature_key] == expected_default
def test_opening_statement_present(self, mock_config, mock_default_file_limits):
# Arrange
features = {"opening_statement": "Hello"}
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=[],
)
# Assert
assert result["opening_statement"] == "Hello"
def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits):
# Arrange
features = {}
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=[],
)
# Assert
assert result["opening_statement"] is None
def test_all_features_provided(self, mock_config, mock_default_file_limits):
# Arrange
features = {
"opening_statement": "Hi",
"suggested_questions": ["Q1"],
"suggested_questions_after_answer": {"enabled": True},
"speech_to_text": {"enabled": True},
"text_to_speech": {"enabled": True},
"retriever_resource": {"enabled": True},
"annotation_reply": {"enabled": True},
"more_like_this": {"enabled": True},
"sensitive_word_avoidance": {
"enabled": True,
"type": "strict",
"configs": ["a"],
},
"file_upload": {
"image": {
"enabled": True,
"number_limits": 10,
"detail": "low",
"transfer_methods": ["local_file"],
}
},
}
user_input = [{"name": "field1"}]
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=user_input,
)
# Assert
for key in features:
assert result[key] == features[key]
assert result["user_input_form"] == user_input
def test_file_upload_default_structure(self, mock_config, mock_default_file_limits):
# Arrange
features = {}
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=[],
)
# Assert
file_upload = result["file_upload"]
assert file_upload["image"]["enabled"] is False
assert file_upload["image"]["number_limits"] == 99
assert file_upload["image"]["detail"] == "high"
assert "remote_url" in file_upload["image"]["transfer_methods"]
assert "local_file" in file_upload["image"]["transfer_methods"]
def test_system_parameters_from_config(self, mock_config, mock_default_file_limits):
# Arrange
features = {}
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=[],
)
# Assert
system_params = result["system_parameters"]
assert system_params["image_file_size_limit"] == 1
assert system_params["video_file_size_limit"] == 2
assert system_params["audio_file_size_limit"] == 3
assert system_params["file_size_limit"] == 4
assert system_params["workflow_file_upload_limit"] == 5
@pytest.mark.parametrize(
("features_dict", "user_input_form"),
[
(None, []),
([], []),
("invalid", []),
],
)
def test_invalid_features_dict_type_raises(self, features_dict, user_input_form):
# Act & Assert
with pytest.raises(AttributeError):
parameters_mapping.get_parameters_from_feature_dict(
features_dict=features_dict,
user_input_form=user_input_form,
)
@pytest.mark.parametrize(
"user_input_form",
[None, "invalid", 123],
)
def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form):
# Arrange
features = {}
# Act
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=user_input_form,
)
# Assert
assert result["user_input_form"] == user_input_form
def test_empty_user_input_form(self, mock_config, mock_default_file_limits):
features = {}
user_input = []
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=user_input,
)
assert result["user_input_form"] == []
def test_feature_values_none(self, mock_config, mock_default_file_limits):
features = {
"suggested_questions": None,
"speech_to_text": None,
}
result = parameters_mapping.get_parameters_from_feature_dict(
features_dict=features,
user_input_form=[],
)
assert result["suggested_questions"] is None
assert result["speech_to_text"] is None

View File

@ -0,0 +1,202 @@
from unittest.mock import MagicMock
import pytest
from core.app.app_config.common.sensitive_word_avoidance.manager import (
SensitiveWordAvoidanceConfigManager,
)
class TestSensitiveWordAvoidanceConfigManagerConvert:
"""Tests for convert classmethod"""
@pytest.mark.parametrize(
"config",
[
{},
{"sensitive_word_avoidance": None},
{"sensitive_word_avoidance": {}},
{"sensitive_word_avoidance": {"enabled": False}},
],
)
def test_convert_returns_none_when_disabled_or_missing(self, config):
# Act
result = SensitiveWordAvoidanceConfigManager.convert(config)
# Assert
assert result is None
def test_convert_returns_entity_when_enabled(self, mocker):
# Arrange
mock_entity = MagicMock()
mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity",
return_value=mock_entity,
)
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": "mock_type",
"config": {"key": "value"},
}
}
# Act
result = SensitiveWordAvoidanceConfigManager.convert(config)
# Assert
assert result == mock_entity
def test_convert_enabled_without_type_or_config(self, mocker):
# Arrange
mock_entity = MagicMock()
patched = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity",
return_value=mock_entity,
)
config = {"sensitive_word_avoidance": {"enabled": True}}
# Act
result = SensitiveWordAvoidanceConfigManager.convert(config)
# Assert
patched.assert_called_once_with(type=None, config={})
assert result == mock_entity
class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults:
"""Tests for validate_and_set_defaults classmethod"""
@pytest.fixture
def base_config(self):
return {}
def test_validate_sets_default_when_missing(self, base_config):
# Act
config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id="tenant1", config=base_config.copy()
)
# Assert
assert config["sensitive_word_avoidance"]["enabled"] is False
assert fields == ["sensitive_word_avoidance"]
def test_validate_raises_when_not_dict(self):
config = {"sensitive_word_avoidance": "invalid"}
with pytest.raises(ValueError, match="must be of dict type"):
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
@pytest.mark.parametrize(
"config",
[
{"sensitive_word_avoidance": {"enabled": False}},
{"sensitive_word_avoidance": {"enabled": None}},
{"sensitive_word_avoidance": {}},
],
)
def test_validate_disables_when_enabled_false_or_missing(self, config):
# Act
result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id="tenant1", config=config
)
# Assert
assert result_config["sensitive_word_avoidance"]["enabled"] is False
def test_validate_raises_when_enabled_true_without_type(self):
config = {"sensitive_word_avoidance": {"enabled": True}}
with pytest.raises(ValueError, match="type is required"):
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
def test_validate_raises_when_type_not_string(self):
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": 123,
}
}
with pytest.raises(ValueError, match="must be a string"):
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
def test_validate_raises_when_config_not_dict(self):
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": "mock_type",
"config": "invalid",
}
}
with pytest.raises(ValueError, match="must be a dict"):
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
def test_validate_calls_moderation_factory(self, mocker):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
)
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": "mock_type",
"config": {"k": "v"},
}
}
# Act
result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id="tenant1", config=config
)
# Assert
mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"})
assert result_config["sensitive_word_avoidance"]["enabled"] is True
assert fields == ["sensitive_word_avoidance"]
def test_validate_sets_empty_dict_when_config_none(self, mocker):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
)
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": "mock_type",
"config": None,
}
}
# Act
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
# Assert
mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={})
def test_validate_only_structure_validate_skips_factory(self, mocker):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
)
config = {
"sensitive_word_avoidance": {
"enabled": True,
"type": "mock_type",
"config": {"k": "v"},
}
}
# Act
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id="tenant1", config=config, only_structure_validate=True
)
# Assert
mock_validate.assert_not_called()

View File

@ -0,0 +1,236 @@
from unittest.mock import MagicMock
import pytest
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
class TestAgentConfigManagerConvert:
@pytest.fixture
def base_config(self):
return {
"agent_mode": {
"enabled": True,
"strategy": "cot",
"tools": [],
},
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "completion",
},
}
def test_convert_returns_none_when_agent_mode_missing(self):
config = {"model": {"provider": "openai", "name": "gpt-4"}}
result = AgentConfigManager.convert(config)
assert result is None
@pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}])
def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config):
config = base_config.copy()
config["agent_mode"] = agent_mode_value
result = AgentConfigManager.convert(config)
assert result is None
@pytest.mark.parametrize(
("strategy_input", "expected_enum"),
[
("function_call", "FUNCTION_CALLING"),
("cot", "CHAIN_OF_THOUGHT"),
("react", "CHAIN_OF_THOUGHT"),
],
)
def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config):
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": strategy_input,
"tools": [],
}
result = AgentConfigManager.convert(config)
assert result is not None
assert result.strategy.name == expected_enum
def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config):
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "unknown_strategy",
"tools": [],
}
config["model"]["provider"] = "openai"
result = AgentConfigManager.convert(config)
assert result.strategy.name == "FUNCTION_CALLING"
def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config):
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "unknown_strategy",
"tools": [],
}
config["model"]["provider"] = "anthropic"
result = AgentConfigManager.convert(config)
assert result.strategy.name == "CHAIN_OF_THOUGHT"
def test_convert_skips_disabled_tools(self, mocker, base_config):
# Patch AgentEntity to bypass pydantic validation
mock_agent_entity = mocker.patch(
"core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity",
return_value=MagicMock(),
)
mock_validate = mocker.patch(
"core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate",
return_value={
"provider_type": "type2",
"provider_id": "id2",
"tool_name": "tool2",
"tool_parameters": {},
"credential_id": None,
},
)
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "cot",
"tools": [
{
"provider_type": "type1",
"provider_id": "id1",
"tool_name": "tool1",
"enabled": False,
},
{
"provider_type": "type2",
"provider_id": "id2",
"tool_name": "tool2",
"enabled": True,
"extra_key": "x",
},
],
}
AgentConfigManager.convert(config)
mock_validate.assert_called_once()
mock_agent_entity.assert_called_once()
def test_convert_tool_requires_minimum_keys(self, mocker, base_config):
mock_validate = mocker.patch(
"core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate",
return_value=MagicMock(),
)
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "cot",
"tools": [
{"a": 1, "b": 2}, # insufficient keys
],
}
result = AgentConfigManager.convert(config)
assert result is not None
assert result.tools == []
mock_validate.assert_not_called()
def test_convert_completion_mode_prompt_defaults(self, base_config):
config = base_config.copy()
config["agent_mode"]["prompt"] = {}
config["model"]["mode"] = "completion"
result = AgentConfigManager.convert(config)
assert result is not None
assert result.prompt.first_prompt is not None
assert result.prompt.next_iteration is not None
def test_convert_chat_mode_prompt_defaults(self, base_config):
config = base_config.copy()
config["agent_mode"]["prompt"] = {}
config["model"]["mode"] = "chat"
result = AgentConfigManager.convert(config)
assert result is not None
assert result.prompt.first_prompt is not None
assert result.prompt.next_iteration is not None
def test_convert_router_strategy_returns_none(self, base_config):
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "router",
"tools": [],
}
result = AgentConfigManager.convert(config)
assert result is None
def test_convert_react_router_strategy_returns_none(self, base_config):
config = base_config.copy()
config["agent_mode"] = {
"enabled": True,
"strategy": "react_router",
"tools": [],
}
result = AgentConfigManager.convert(config)
assert result is None
def test_convert_max_iteration_default(self, base_config):
config = base_config.copy()
config["agent_mode"].pop("max_iteration", None)
result = AgentConfigManager.convert(config)
assert result.max_iteration == 10
def test_convert_custom_max_iteration(self, base_config):
config = base_config.copy()
config["agent_mode"]["max_iteration"] = 25
result = AgentConfigManager.convert(config)
assert result.max_iteration == 25
def test_convert_missing_model_raises_key_error(self, base_config):
config = base_config.copy()
del config["model"]
with pytest.raises(KeyError):
AgentConfigManager.convert(config)
@pytest.mark.parametrize(
("invalid_config", "should_raise"),
[
(None, True),
(123, True),
("", False),
([], False),
],
)
def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise):
if should_raise:
with pytest.raises(TypeError):
AgentConfigManager.convert(invalid_config) # type: ignore
else:
result = AgentConfigManager.convert(invalid_config) # type: ignore
assert result is None

View File

@ -0,0 +1,319 @@
import uuid
from unittest.mock import MagicMock
import pytest
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode
# ==============================
# Fixtures
# ==============================
@pytest.fixture
def valid_uuid():
return str(uuid.uuid4())
@pytest.fixture
def base_config(valid_uuid):
return {
"dataset_configs": {
"retrieval_model": "multiple",
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
},
}
}
@pytest.fixture
def mock_dataset_service(mocker, valid_uuid):
mock_dataset = MagicMock()
mock_dataset.tenant_id = "tenant1"
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=mock_dataset,
)
# ==============================
# convert tests
# ==============================
class TestDatasetConfigManagerConvert:
def test_convert_returns_none_when_no_datasets(self):
config = {"dataset_configs": {"datasets": {"datasets": []}}}
result = DatasetConfigManager.convert(config)
assert result is None
def test_convert_single_retrieval(self, valid_uuid):
config = {
"dataset_query_variable": "query",
"dataset_configs": {
"retrieval_model": "single",
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
},
},
}
result = DatasetConfigManager.convert(config)
assert result is not None
assert result.dataset_ids == [valid_uuid]
assert result.retrieve_config.query_variable == "query"
def test_convert_single_with_metadata_configs(self, valid_uuid, mocker):
mock_retrieve_config = MagicMock()
mock_entity = MagicMock()
mock_entity.dataset_ids = [valid_uuid]
mock_entity.retrieve_config = mock_retrieve_config
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig",
return_value={"mock": "model"},
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition",
return_value={"mock": "condition"},
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity",
return_value=mock_retrieve_config,
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity",
return_value=mock_entity,
)
config = {
"dataset_query_variable": "query",
"dataset_configs": {
"retrieval_model": "single",
"metadata_filtering_mode": "manual",
"metadata_model_config": {"any": "value"},
"metadata_filtering_conditions": {"any": "value"},
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
},
},
}
result = DatasetConfigManager.convert(config)
assert result.dataset_ids == [valid_uuid]
assert result.retrieve_config is mock_retrieve_config
def test_convert_multiple_defaults(self, valid_uuid):
config = {
"dataset_configs": {
"retrieval_model": "multiple",
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
},
}
}
result = DatasetConfigManager.convert(config)
assert result.retrieve_config.top_k == 4
assert result.retrieve_config.score_threshold is None
assert result.retrieve_config.reranking_enabled is True
def test_convert_agent_mode_disabled_tool(self, valid_uuid):
config = {
"agent_mode": {
"enabled": True,
"tools": [{"dataset": {"id": valid_uuid, "enabled": False}}],
}
}
result = DatasetConfigManager.convert(config)
assert result is None
def test_convert_dataset_configs_none(self):
config = {"dataset_configs": None}
with pytest.raises(TypeError):
DatasetConfigManager.convert(config)
def test_convert_agent_mode_old_style_old_format(self, valid_uuid):
config = {
"agent_mode": {
"enabled": True,
"tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
}
}
result = DatasetConfigManager.convert(config)
assert result.dataset_ids == [valid_uuid]
assert result.retrieve_config.query_variable is None
def test_convert_multiple_with_score_threshold(self, valid_uuid):
config = {
"dataset_query_variable": "query",
"dataset_configs": {
"retrieval_model": "multiple",
"top_k": 5,
"score_threshold": 0.8,
"score_threshold_enabled": True,
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
},
},
}
result = DatasetConfigManager.convert(config)
assert result.retrieve_config.top_k == 5
assert result.retrieve_config.score_threshold == 0.8
@pytest.mark.parametrize(
"dataset_entry",
[
{},
{"invalid": {}},
{"dataset": {"id": None, "enabled": True}},
{"dataset": {"id": "", "enabled": False}},
],
)
def test_convert_ignores_invalid_dataset_entries(self, dataset_entry):
config = {
"dataset_configs": {
"retrieval_model": "multiple",
"datasets": {"strategy": "router", "datasets": [dataset_entry]},
}
}
result = DatasetConfigManager.convert(config)
assert result is None
def test_convert_agent_mode_old_style(self, valid_uuid):
config = {
"agent_mode": {
"enabled": True,
"tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
}
}
result = DatasetConfigManager.convert(config)
assert result.dataset_ids == [valid_uuid]
# ==============================
# validate_and_set_defaults tests
# ==============================
class TestValidateAndSetDefaults:
def test_validate_sets_defaults(self):
config = {}
updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config)
assert "dataset_configs" in updated
assert updated["dataset_configs"]["retrieval_model"] == "single"
assert isinstance(fields, list)
def test_validate_raises_when_dataset_configs_not_dict(self):
config = {"dataset_configs": "invalid"}
with pytest.raises(AttributeError):
DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config)
def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid):
config = {
"dataset_configs": {
"datasets": {
"strategy": "router",
"datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}],
}
}
}
with pytest.raises(ValueError):
DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config)
# ==============================
# extract_dataset_config_for_legacy_compatibility tests
# ==============================
class TestExtractDatasetConfig:
def test_extract_sets_defaults(self):
config = {}
result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
assert "agent_mode" in result
assert result["agent_mode"]["enabled"] is False
assert result["agent_mode"]["tools"] == []
def test_extract_invalid_agent_mode_type(self):
config = {"agent_mode": "invalid"}
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_invalid_enabled_type(self):
config = {"agent_mode": {"enabled": "yes"}}
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_invalid_tools_type(self):
config = {"agent_mode": {"enabled": True, "tools": "invalid"}}
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_invalid_uuid(self, mocker):
invalid_uuid = "not-a-uuid"
config = {
"agent_mode": {
"enabled": True,
"strategy": PlanningStrategy.ROUTER,
"tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}],
}
}
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_dataset_not_exists(self, valid_uuid, mocker):
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=None,
)
config = {
"agent_mode": {
"enabled": True,
"strategy": PlanningStrategy.ROUTER,
"tools": [{"dataset": {"id": valid_uuid, "enabled": True}}],
}
}
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
# ==============================
# is_dataset_exists tests
# ==============================
class TestIsDatasetExists:
def test_dataset_exists_true(self, mocker, valid_uuid):
mock_dataset = MagicMock()
mock_dataset.tenant_id = "tenant1"
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=mock_dataset,
)
assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)
def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid):
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=None,
)
assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)
def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid):
mock_dataset = MagicMock()
mock_dataset.tenant_id = "other"
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=mock_dataset,
)
assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)

View File

@ -0,0 +1,234 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.entities.model_entities import ModelStatus
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
class TestModelConfigConverter:
@pytest.fixture(autouse=True)
def patch_response_entity(self, mocker):
"""
Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation
and return a simple namespace object instead.
"""
def _factory(**kwargs):
return SimpleNamespace(**kwargs)
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity",
side_effect=_factory,
)
@pytest.fixture
def mock_app_config(self):
app_config = MagicMock()
app_config.tenant_id = "tenant_1"
model_config = MagicMock()
model_config.provider = "openai"
model_config.model = "gpt-4"
model_config.parameters = {"temperature": 0.5}
model_config.mode = None
app_config.model = model_config
return app_config
@pytest.fixture
def mock_provider_bundle(self):
bundle = MagicMock()
# configuration
configuration = MagicMock()
configuration.provider.provider = "openai"
configuration.get_current_credentials.return_value = {"api_key": "key"}
provider_model = MagicMock()
provider_model.status = ModelStatus.ACTIVE
configuration.get_provider_model.return_value = provider_model
bundle.configuration = configuration
# model type instance
model_type_instance = MagicMock()
model_schema = MagicMock()
model_schema.model_properties = {}
model_type_instance.get_model_schema.return_value = model_schema
bundle.model_type_instance = model_type_instance
return bundle
@pytest.fixture
def patch_provider_manager(self, mocker, mock_provider_bundle):
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
return mock_manager
# =============================
# Positive Scenarios
# =============================
def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager):
result = ModelConfigConverter.convert(mock_app_config)
assert result.provider == "openai"
assert result.model == "gpt-4"
assert result.mode == LLMMode.CHAT
assert result.parameters == {"temperature": 0.5}
assert result.stop == []
def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager):
mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]}
result = ModelConfigConverter.convert(mock_app_config)
assert result.parameters == {"temperature": 0.7}
assert result.stop == ["\n"]
def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker):
mock_app_config.model.mode = None
mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
ModelPropertyKey.MODE: LLMMode.COMPLETION.value
}
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
result = ModelConfigConverter.convert(mock_app_config)
assert result.mode == LLMMode.COMPLETION
def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker):
mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
ModelPropertyKey.MODE: "invalid"
}
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
result = ModelConfigConverter.convert(mock_app_config)
assert result.mode == LLMMode.CHAT
# =============================
# Credential Errors
# =============================
def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
mock_provider_bundle.configuration.get_current_credentials.return_value = None
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
with pytest.raises(ProviderTokenNotInitError):
ModelConfigConverter.convert(mock_app_config)
# =============================
# Provider Model Errors
# =============================
def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
mock_provider_bundle.configuration.get_provider_model.return_value = None
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
with pytest.raises(ValueError):
ModelConfigConverter.convert(mock_app_config)
@pytest.mark.parametrize(
("status", "expected_exception"),
[
(ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError),
(ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError),
(ModelStatus.QUOTA_EXCEEDED, QuotaExceededError),
],
)
def test_convert_provider_model_status_errors(
self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception
):
mock_provider = MagicMock()
mock_provider.status = status
mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
with pytest.raises(expected_exception):
ModelConfigConverter.convert(mock_app_config)
# =============================
# Schema Errors
# =============================
def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
mock_provider_bundle.model_type_instance.get_model_schema.return_value = None
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
"core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
return_value=mock_manager,
)
with pytest.raises(ValueError):
ModelConfigConverter.convert(mock_app_config)
# =============================
# Edge Cases
# =============================
@pytest.mark.parametrize(
"parameters",
[
{},
{"stop": []},
{"stop": ["END"], "max_tokens": 100},
],
)
def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters):
mock_app_config.model.parameters = parameters.copy()
result = ModelConfigConverter.convert(mock_app_config)
if "stop" in parameters:
assert result.stop == parameters.get("stop")
expected_params = parameters.copy()
expected_params.pop("stop", None)
assert result.parameters == expected_params
else:
assert result.stop == []
assert result.parameters == parameters

View File

@ -0,0 +1,230 @@
from unittest.mock import MagicMock
import pytest
# Target
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
# -----------------------------
# Fixtures
# -----------------------------
@pytest.fixture
def valid_completion_params():
return {"temperature": 0.7, "stop": ["\n"]}
@pytest.fixture
def valid_model_list():
model = MagicMock()
model.model = "gpt-4"
model.model_properties = {"mode": "chat"}
return [model]
@pytest.fixture
def provider_entities():
provider = MagicMock()
provider.provider = "openai/gpt"
return [provider]
@pytest.fixture
def valid_config():
return {
"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}}
}
# -----------------------------
# Test Class
# -----------------------------
class TestModelConfigManager:
# ==========================================================
# convert
# ==========================================================
def test_convert_success(self, valid_config):
result = ModelConfigManager.convert(valid_config)
assert result.provider == "openai/gpt"
assert result.model == "gpt-4"
assert result.parameters == {"temperature": 0.5}
assert result.stop == ["END"]
def test_convert_missing_model(self):
with pytest.raises(ValueError, match="model is required"):
ModelConfigManager.convert({})
def test_convert_without_stop(self):
config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}}
result = ModelConfigManager.convert(config)
assert result.stop == []
assert result.parameters == {"temperature": 0.9}
# ==========================================================
# validate_model_completion_params
# ==========================================================
@pytest.mark.parametrize(
"invalid_cp",
[None, "string", 123, []],
)
def test_validate_model_completion_params_invalid_type(self, invalid_cp):
with pytest.raises(ValueError, match="must be of object type"):
ModelConfigManager.validate_model_completion_params(invalid_cp)
def test_validate_model_completion_params_default_stop(self):
cp = {"temperature": 0.2}
result = ModelConfigManager.validate_model_completion_params(cp)
assert result["stop"] == []
def test_validate_model_completion_params_invalid_stop_type(self):
cp = {"stop": "invalid"}
with pytest.raises(ValueError, match="must be of list type"):
ModelConfigManager.validate_model_completion_params(cp)
def test_validate_model_completion_params_stop_length_exceeded(self):
cp = {"stop": [1, 2, 3, 4, 5]}
with pytest.raises(ValueError, match="less than 4"):
ModelConfigManager.validate_model_completion_params(cp)
# ==========================================================
# validate_and_set_defaults
# ==========================================================
def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list):
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list
updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config)
assert updated_config["model"]["mode"] == "chat"
assert keys == ["model"]
def test_validate_and_set_defaults_missing_model(self):
with pytest.raises(ValueError, match="model is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", {})
def test_validate_and_set_defaults_model_not_dict(self):
with pytest.raises(ValueError, match="object type"):
ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"})
def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities):
config = {"model": {"name": "gpt-4", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
with pytest.raises(ValueError, match="model.provider is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities):
config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
with pytest.raises(ValueError, match="model.provider is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities):
config = {"model": {"provider": "openai/gpt", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
with pytest.raises(ValueError, match="model.name is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities):
config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = []
with pytest.raises(ValueError, match="must be in the specified model list"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list):
config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list
with pytest.raises(ValueError, match="must be in the specified model list"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities):
model = MagicMock()
model.model = "gpt-4"
model.model_properties = {}
config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model]
updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config)
assert updated_config["model"]["mode"] == "completion"
def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list):
config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}}
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
mock_factory.return_value.get_providers.return_value = provider_entities
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list
with pytest.raises(ValueError, match="completion_params is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list):
"""
Covers branch where provider does not contain '/' and
ModelProviderID conversion is triggered (line 64).
"""
config = {
"model": {
"provider": "openai", # no slash -> triggers conversion
"name": "gpt-4",
"completion_params": {},
}
}
# Mock ModelProviderID to return formatted provider
mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID")
mock_provider_id.return_value = "openai/gpt"
# Mock provider factory
mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory")
provider_entity = MagicMock()
provider_entity.provider = "openai/gpt"
mock_factory.return_value.get_providers.return_value = [provider_entity]
# Mock provider manager
mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager")
mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list
updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config)
# Ensure conversion happened
mock_provider_id.assert_called_once_with("openai")
assert updated_config["model"]["provider"] == "openai/gpt"

View File

@ -0,0 +1,292 @@
from unittest.mock import MagicMock
import pytest
from core.app.app_config.easy_ui_based_app.prompt_template.manager import (
PromptTemplateConfigManager,
)
# -----------------------------
# Helpers
# -----------------------------
class DummyEnumValue:
def __init__(self, value):
self.value = value
class DummyPromptType:
def __init__(self):
self.SIMPLE = "simple"
self.ADVANCED = "advanced"
def value_of(self, value):
return value
def __iter__(self):
return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")])
# -----------------------------
# Convert Tests
# -----------------------------
class TestPromptTemplateConfigManagerConvert:
def test_convert_missing_prompt_type_raises(self):
with pytest.raises(ValueError, match="prompt_type is required"):
PromptTemplateConfigManager.convert({})
def test_convert_simple_prompt(self, mocker):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
mock_prompt_entity_cls.return_value = "simple_entity"
config = {"prompt_type": "simple", "pre_prompt": "hello"}
result = PromptTemplateConfigManager.convert(config)
assert result == "simple_entity"
mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello")
def test_convert_advanced_chat_valid(self, mocker):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mock_prompt_entity_cls.return_value = "advanced_entity"
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of",
return_value="role_enum",
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity",
return_value="chat_msg",
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity",
return_value="chat_template",
)
config = {
"prompt_type": "advanced",
"chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]},
}
result = PromptTemplateConfigManager.convert(config)
assert result == "advanced_entity"
@pytest.mark.parametrize(
"message",
[
{"text": 123, "role": "user"},
{"text": "hi", "role": 123},
],
)
def test_convert_advanced_invalid_message_fields(self, mocker, message):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
config = {
"prompt_type": "advanced",
"chat_prompt_config": {"prompt": [message]},
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.convert(config)
def test_convert_advanced_completion_with_roles(self, mocker):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mock_prompt_entity_cls.return_value = "advanced_entity"
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity",
return_value="completion_template",
)
config = {
"prompt_type": "advanced",
"completion_prompt_config": {
"prompt": {"text": "complete"},
"conversation_histories_role": {
"user_prefix": "U",
"assistant_prefix": "A",
},
},
}
result = PromptTemplateConfigManager.convert(config)
assert result == "advanced_entity"
# -----------------------------
# validate_and_set_defaults
# -----------------------------
class TestValidateAndSetDefaults:
def setup_method(self):
self.valid_model = {"mode": "chat"}
def _patch_prompt_type(self, mocker):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
return mock_prompt_entity_cls
def test_default_prompt_type_set(self, mocker):
self._patch_prompt_type(mocker)
config = {"model": self.valid_model}
result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
assert result["prompt_type"] == "simple"
assert isinstance(keys, list)
def test_invalid_prompt_type_raises(self, mocker):
class InvalidEnum(DummyPromptType):
def __iter__(self):
return iter([DummyEnumValue("valid")])
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = InvalidEnum()
mocker.patch(
"core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity",
mock_prompt_entity_cls,
)
config = {"prompt_type": "invalid", "model": self.valid_model}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_invalid_chat_prompt_config_type(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "simple",
"chat_prompt_config": "invalid",
"model": self.valid_model,
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_simple_mode_invalid_pre_prompt_type(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "simple",
"pre_prompt": 123,
"model": self.valid_model,
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_requires_one_config(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "advanced",
"chat_prompt_config": {},
"completion_prompt_config": {},
"model": {"mode": "chat"},
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_invalid_model_mode(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "advanced",
"chat_prompt_config": {"prompt": []},
"model": {"mode": "invalid"},
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_chat_prompt_length_exceeds(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "advanced",
"chat_prompt_config": {"prompt": [{}] * 11},
"model": {"mode": "chat"},
}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_completion_prefix_defaults_set_when_empty(self, mocker):
self._patch_prompt_type(mocker)
config = {
"prompt_type": "advanced",
"completion_prompt_config": {
"prompt": {"text": "hi"},
"conversation_histories_role": {
"user_prefix": "",
"assistant_prefix": "",
},
},
"model": {"mode": "completion"},
}
updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config)
roles = updated["completion_prompt_config"]["conversation_histories_role"]
assert roles["user_prefix"] == "Human"
assert roles["assistant_prefix"] == "Assistant"
# -----------------------------
# validate_post_prompt
# -----------------------------
class TestValidatePostPrompt:
@pytest.mark.parametrize("value", [None, ""])
def test_post_prompt_defaults(self, value):
config = {"post_prompt": value}
result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config)
assert result["post_prompt"] == ""
def test_post_prompt_invalid_type(self):
config = {"post_prompt": 123}
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config)

View File

@ -0,0 +1,286 @@
import pytest
from core.app.app_config.easy_ui_based_app.variables.manager import (
BasicVariablesConfigManager,
)
from dify_graph.variables.input_entities import VariableEntityType
class TestBasicVariablesConfigManagerConvert:
def test_convert_empty_config(self):
config = {}
variables, external = BasicVariablesConfigManager.convert(config)
assert variables == []
assert external == []
def test_convert_external_data_tools_enabled_and_disabled(self, mocker):
config = {
"external_data_tools": [
{"enabled": False},
{
"enabled": True,
"variable": "ext_var",
"type": "tool_type",
"config": {"k": "v"},
},
]
}
variables, external = BasicVariablesConfigManager.convert(config)
assert variables == []
assert len(external) == 1
assert external[0].variable == "ext_var"
assert external[0].type == "tool_type"
def test_convert_user_input_form_variable_types(self):
config = {
"user_input_form": [
{
VariableEntityType.TEXT_INPUT: {
"variable": "name",
"label": "Name",
"description": "desc",
"required": True,
"max_length": 50,
}
},
{
VariableEntityType.SELECT: {
"variable": "choice",
"label": "Choice",
"options": ["a", "b"],
}
},
{
VariableEntityType.EXTERNAL_DATA_TOOL: {
"variable": "ext",
"type": "tool",
"config": {"x": 1},
}
},
]
}
variables, external = BasicVariablesConfigManager.convert(config)
assert len(variables) == 2
assert len(external) == 1
def test_convert_external_data_tool_without_config_skipped(self):
config = {
"user_input_form": [
{
VariableEntityType.EXTERNAL_DATA_TOOL: {
"variable": "ext",
"type": "tool",
}
}
]
}
variables, external = BasicVariablesConfigManager.convert(config)
assert variables == []
assert external == []
class TestValidateVariablesAndSetDefaults:
def test_validate_sets_empty_user_input_form_if_missing(self):
config = {}
updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
assert updated["user_input_form"] == []
assert "user_input_form" in keys
def test_validate_user_input_form_not_list_raises(self):
config = {"user_input_form": "invalid"}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_invalid_key_raises(self):
config = {"user_input_form": [{"invalid": {}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_missing_label_raises(self):
config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_label_not_string_raises(self):
config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_missing_variable_raises(self):
config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_variable_not_string_raises(self):
config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
@pytest.mark.parametrize(
"variable_name",
["1invalid", "invalid space", "", None],
)
def test_validate_variable_invalid_pattern_raises(self, variable_name):
config = {
"user_input_form": [
{
VariableEntityType.TEXT_INPUT: {
"label": "Name",
"variable": variable_name,
}
}
]
}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_required_default_and_type(self):
config = {
"user_input_form": [
{
VariableEntityType.TEXT_INPUT: {
"label": "Name",
"variable": "valid_name",
}
}
]
}
updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False
def test_validate_required_not_bool_raises(self):
config = {
"user_input_form": [
{
VariableEntityType.TEXT_INPUT: {
"label": "Name",
"variable": "valid_name",
"required": "yes",
}
}
]
}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_select_options_default_not_in_options_raises(self):
config = {
"user_input_form": [
{
VariableEntityType.SELECT: {
"label": "Choice",
"variable": "choice",
"options": ["a", "b"],
"default": "c",
}
}
]
}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
def test_validate_select_options_not_list_raises(self):
config = {
"user_input_form": [
{
VariableEntityType.SELECT: {
"label": "Choice",
"variable": "choice",
"options": "not_list",
}
}
]
}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_variables_and_set_defaults(config)
class TestValidateExternalDataToolsAndSetDefaults:
def test_validate_sets_empty_external_data_tools_if_missing(self):
config = {}
updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
assert updated["external_data_tools"] == []
assert "external_data_tools" in keys
def test_validate_external_data_tools_not_list_raises(self):
config = {"external_data_tools": "invalid"}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
def test_validate_disabled_tool_skipped(self, mocker):
config = {"external_data_tools": [{"enabled": False}]}
spy = mocker.patch(
"core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config"
)
updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
spy.assert_not_called()
assert updated["external_data_tools"][0]["enabled"] is False
def test_validate_enabled_tool_missing_type_raises(self):
config = {"external_data_tools": [{"enabled": True, "config": {}}]}
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
def test_validate_enabled_tool_calls_factory(self, mocker):
config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]}
spy = mocker.patch(
"core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config"
)
BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config)
spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1})
class TestValidateAndSetDefaultsIntegration:
def test_validate_and_set_defaults_calls_both(self, mocker):
config = {}
spy_var = mocker.patch.object(
BasicVariablesConfigManager,
"validate_variables_and_set_defaults",
return_value=(config, ["user_input_form"]),
)
spy_ext = mocker.patch.object(
BasicVariablesConfigManager,
"validate_external_data_tools_and_set_defaults",
return_value=(config, ["external_data_tools"]),
)
updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config)
spy_var.assert_called_once()
spy_ext.assert_called_once()
assert "user_input_form" in keys
assert "external_data_tools" in keys
assert updated == config

View File

@ -0,0 +1,115 @@
import pytest
from core.app.app_config.entities import TextToSpeechEntity
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
class TestAdditionalFeatureManagers:
def test_opening_statement_validate_defaults(self):
config, keys = OpeningStatementConfigManager.validate_and_set_defaults({})
assert config["opening_statement"] == ""
assert config["suggested_questions"] == []
assert set(keys) == {"opening_statement", "suggested_questions"}
def test_opening_statement_validate_types(self):
with pytest.raises(ValueError):
OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123})
with pytest.raises(ValueError):
OpeningStatementConfigManager.validate_and_set_defaults(
{"opening_statement": "hi", "suggested_questions": "bad"}
)
with pytest.raises(ValueError):
OpeningStatementConfigManager.validate_and_set_defaults(
{"opening_statement": "hi", "suggested_questions": [1]}
)
def test_opening_statement_convert(self):
opening, questions = OpeningStatementConfigManager.convert(
{"opening_statement": "hello", "suggested_questions": ["q1"]}
)
assert opening == "hello"
assert questions == ["q1"]
def test_retrieval_resource_validate(self):
config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({})
assert config["retriever_resource"]["enabled"] is False
assert keys == ["retriever_resource"]
with pytest.raises(ValueError):
RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"})
with pytest.raises(ValueError):
RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}})
def test_retrieval_resource_convert(self):
assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True
assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False
def test_speech_to_text_validate_and_convert(self):
config, keys = SpeechToTextConfigManager.validate_and_set_defaults({})
assert config["speech_to_text"]["enabled"] is False
assert keys == ["speech_to_text"]
with pytest.raises(ValueError):
SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"})
with pytest.raises(ValueError):
SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}})
assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True
assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False
def test_suggested_questions_after_answer_validate_and_convert(self):
config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({})
assert config["suggested_questions_after_answer"]["enabled"] is False
assert keys == ["suggested_questions_after_answer"]
with pytest.raises(ValueError):
SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
{"suggested_questions_after_answer": "bad"}
)
with pytest.raises(ValueError):
SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
{"suggested_questions_after_answer": {"enabled": "yes"}}
)
assert (
SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}})
is True
)
assert (
SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}})
is False
)
def test_text_to_speech_validate_and_convert(self):
config, keys = TextToSpeechConfigManager.validate_and_set_defaults({})
assert config["text_to_speech"]["enabled"] is False
assert keys == ["text_to_speech"]
with pytest.raises(ValueError):
TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"})
with pytest.raises(ValueError):
TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}})
result = TextToSpeechConfigManager.convert(
{"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}}
)
assert isinstance(result, TextToSpeechEntity)
assert result.voice == "v"
assert result.language == "en"
def test_more_like_this_convert_and_validate(self):
config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({})
assert config["more_like_this"]["enabled"] is False
assert keys == ["more_like_this"]
assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True
assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False
with pytest.raises(ValueError):
MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"})

View File

@ -0,0 +1,180 @@
from collections import UserDict
from unittest.mock import MagicMock
import pytest
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
class TestBaseAppConfigManager:
@pytest.fixture
def mock_config_dict(self):
return {"key": "value", "another": 123}
@pytest.fixture
def mock_app_additional_features(self, mocker):
mock_instance = MagicMock()
mocker.patch(
"core.app.app_config.base_app_config_manager.AppAdditionalFeatures",
return_value=mock_instance,
)
return mock_instance
@pytest.fixture
def mock_managers(self, mocker):
retrieval = mocker.patch(
"core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
return_value="retrieval_result",
)
file_upload = mocker.patch(
"core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert",
return_value="file_upload_result",
)
opening_statement = mocker.patch(
"core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert",
return_value=("opening_result", "suggested_result"),
)
suggested_after = mocker.patch(
"core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert",
return_value="suggested_after_result",
)
more_like_this = mocker.patch(
"core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert",
return_value="more_like_this_result",
)
speech_to_text = mocker.patch(
"core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert",
return_value="speech_to_text_result",
)
text_to_speech = mocker.patch(
"core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert",
return_value="text_to_speech_result",
)
return {
"retrieval": retrieval,
"file_upload": file_upload,
"opening_statement": opening_statement,
"suggested_after": suggested_after,
"more_like_this": more_like_this,
"speech_to_text": speech_to_text,
"text_to_speech": text_to_speech,
}
@pytest.mark.parametrize(
("app_mode", "expected_is_vision"),
[
("CHAT", True),
("COMPLETION", True),
("AGENT_CHAT", True),
("OTHER", False),
],
)
def test_convert_features_all_modes(
self,
mocker,
mock_config_dict,
mock_app_additional_features,
mock_managers,
app_mode,
expected_is_vision,
):
# Arrange
mock_app_mode = MagicMock()
mock_app_mode.CHAT = "CHAT"
mock_app_mode.COMPLETION = "COMPLETION"
mock_app_mode.AGENT_CHAT = "AGENT_CHAT"
mocker.patch(
"core.app.app_config.base_app_config_manager.AppMode",
mock_app_mode,
)
# Act
result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode)
# Assert
assert result == mock_app_additional_features
mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["file_upload"].assert_called_once()
_, kwargs = mock_managers["file_upload"].call_args
assert kwargs["config"] == dict(mock_config_dict.items())
assert kwargs["is_vision"] is expected_is_vision
mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items()))
def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers):
# Arrange
empty_config = {}
mock_app_mode = MagicMock()
mock_app_mode.CHAT = "CHAT"
mock_app_mode.COMPLETION = "COMPLETION"
mock_app_mode.AGENT_CHAT = "AGENT_CHAT"
mocker.patch(
"core.app.app_config.base_app_config_manager.AppMode",
mock_app_mode,
)
# Act
result = BaseAppConfigManager.convert_features(empty_config, "CHAT")
# Assert
assert result == mock_app_additional_features
for manager in mock_managers.values():
assert manager.called
@pytest.mark.parametrize(
"invalid_config",
[
None,
"string",
123,
12.34,
[],
],
)
def test_convert_features_invalid_config_raises(self, invalid_config):
# Act & Assert
with pytest.raises((TypeError, AttributeError)):
BaseAppConfigManager.convert_features(invalid_config, "CHAT")
def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict):
# Arrange
mocker.patch(
"core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
side_effect=RuntimeError("manager failure"),
)
# Act & Assert
with pytest.raises(RuntimeError):
BaseAppConfigManager.convert_features(mock_config_dict, "CHAT")
def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers):
# Arrange
class CustomMapping(UserDict):
pass
custom_config = CustomMapping({"a": 1})
mock_app_mode = MagicMock()
mock_app_mode.CHAT = "CHAT"
mock_app_mode.COMPLETION = "COMPLETION"
mock_app_mode.AGENT_CHAT = "AGENT_CHAT"
mocker.patch(
"core.app.app_config.base_app_config_manager.AppMode",
mock_app_mode,
)
# Act
result = BaseAppConfigManager.convert_features(custom_config, "CHAT")
# Assert
assert result == mock_app_additional_features
for manager in mock_managers.values():
assert manager.called

View File

@ -0,0 +1,43 @@
import pytest
from core.app.app_config.entities import (
DatasetRetrieveConfigEntity,
PromptTemplateEntity,
)
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
class TestAppConfigEntities:
def test_variable_entity_coerces_none_description_and_options(self):
entity = VariableEntity(
variable="query",
label="Query",
description=None,
type=VariableEntityType.TEXT_INPUT,
options=None,
)
assert entity.description == ""
assert entity.options == []
def test_variable_entity_rejects_invalid_json_schema(self):
with pytest.raises(ValueError):
VariableEntity(
variable="query",
label="Query",
type=VariableEntityType.TEXT_INPUT,
json_schema={"type": "string", "minLength": "bad"},
)
def test_prompt_template_value_of(self):
assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE
with pytest.raises(ValueError):
PromptTemplateEntity.PromptType.value_of("missing")
def test_dataset_retrieve_strategy_value_of(self):
assert (
DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single")
== DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
)
with pytest.raises(ValueError):
DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing")

View File

@ -0,0 +1,222 @@
import pytest
from core.app.app_config.workflow_ui_based_app.variables.manager import (
WorkflowVariablesConfigManager,
)
# =============================
# Fixtures
# =============================
@pytest.fixture
def mock_workflow(mocker):
workflow = mocker.MagicMock()
workflow.graph_dict = {"nodes": []}
return workflow
@pytest.fixture
def mock_variable_entity(mocker):
return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity")
@pytest.fixture
def mock_rag_entity(mocker):
return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity")
# =============================
# Test Convert (user_input_form)
# =============================
class TestWorkflowVariablesConfigManagerConvert:
def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity):
# Arrange
input_variables = [{"name": "var1"}, {"name": "var2"}]
mock_workflow.user_input_form.return_value = input_variables
mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert(mock_workflow)
# Assert
assert result == [{"validated": v} for v in input_variables]
assert mock_variable_entity.model_validate.call_count == 2
def test_convert_empty_list(self, mock_workflow, mock_variable_entity):
# Arrange
mock_workflow.user_input_form.return_value = []
# Act
result = WorkflowVariablesConfigManager.convert(mock_workflow)
# Assert
assert result == []
mock_variable_entity.model_validate.assert_not_called()
def test_convert_none_returned_raises(self, mock_workflow):
# Arrange
mock_workflow.user_input_form.return_value = None
# Act & Assert
with pytest.raises(TypeError):
WorkflowVariablesConfigManager.convert(mock_workflow)
def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity):
# Arrange
mock_workflow.user_input_form.return_value = [{"invalid": "data"}]
mock_variable_entity.model_validate.side_effect = ValueError("validation error")
# Act & Assert
with pytest.raises(ValueError):
WorkflowVariablesConfigManager.convert(mock_workflow)
# =============================
# Test convert_rag_pipeline_variable
# =============================
class TestWorkflowVariablesConfigManagerConvertRag:
def test_no_rag_pipeline_variables(self, mock_workflow):
# Arrange
mock_workflow.rag_pipeline_variables = []
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert result == []
def test_rag_pipeline_none(self, mock_workflow):
# Arrange
mock_workflow.rag_pipeline_variables = None
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert result == []
def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": "node1"},
]
mock_workflow.graph_dict = {"nodes": []}
mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}]
def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": "node1"},
{"variable": "var2", "belong_to_node_id": "node1"},
]
mock_workflow.graph_dict = {
"nodes": [
{
"id": "node1",
"data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}},
}
]
}
mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert len(result) == 1
assert result[0]["validated"]["variable"] == "var2"
def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": "node1"},
{"variable": "var2", "belong_to_node_id": "node1"},
]
mock_workflow.graph_dict = {
"nodes": [
{
"id": "node1",
"data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}},
}
]
}
mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert len(result) == 1
assert result[0]["validated"]["variable"] == "var2"
@pytest.mark.parametrize(
("belong_to_node_id", "expected_count"),
[
("node1", 1),
("shared", 1),
("other_node", 0),
],
)
def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": belong_to_node_id},
]
mock_workflow.graph_dict = {"nodes": []}
mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert len(result) == expected_count
def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": "node1"},
]
mock_workflow.graph_dict = {
"nodes": [
{
"id": "node1",
"data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}},
}
]
}
mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x}
# Act
result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")
# Assert
assert len(result) == 1
def test_validation_error_propagates(self, mock_workflow, mock_rag_entity):
# Arrange
mock_workflow.rag_pipeline_variables = [
{"variable": "var1", "belong_to_node_id": "node1"},
]
mock_workflow.graph_dict = {"nodes": []}
mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed")
# Act & Assert
with pytest.raises(RuntimeError):
WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1")

View File

@ -0,0 +1,12 @@
from core.app.entities.queue_entities import QueueStopEvent
class TestQueueEntities:
def test_get_stop_reason_for_known_stop_by(self):
event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
assert event.get_stop_reason() == "Stopped by user."
def test_get_stop_reason_for_unknown_stop_by(self):
event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
event.stopped_by = "unknown"
assert event.get_stop_reason() == "Stopped by unknown reason."

View File

@ -0,0 +1,17 @@
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
class TestRagPipelineInvokeEntity:
def test_defaults_and_fields(self):
entity = RagPipelineInvokeEntity(
pipeline_id="pipe-1",
application_generate_entity={"foo": "bar"},
user_id="user-1",
tenant_id="tenant-1",
workflow_id="workflow-1",
streaming=True,
)
assert entity.workflow_execution_id is None
assert entity.workflow_thread_pool_id is None
assert entity.streaming is True

View File

@ -0,0 +1,78 @@
from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamEvent,
)
from dify_graph.enums import WorkflowNodeExecutionStatus
class TestTaskEntities:
def test_node_start_to_ignore_detail_dict(self):
data = NodeStartStreamResponse.Data(
id="exec-1",
node_id="node-1",
node_type="answer",
title="Answer",
index=1,
predecessor_node_id=None,
inputs={"foo": "bar"},
created_at=1,
)
response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)
payload = response.to_ignore_detail_dict()
assert payload["event"] == StreamEvent.NODE_STARTED.value
assert payload["data"]["inputs"] is None
assert payload["data"]["extras"] == {}
def test_node_finish_to_ignore_detail_dict(self):
data = NodeFinishStreamResponse.Data(
id="exec-1",
node_id="node-1",
node_type="answer",
title="Answer",
index=1,
predecessor_node_id=None,
inputs={"foo": "bar"},
process_data={"step": 1},
outputs={"answer": "ok"},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
elapsed_time=0.1,
created_at=1,
finished_at=2,
)
response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)
payload = response.to_ignore_detail_dict()
assert payload["event"] == StreamEvent.NODE_FINISHED.value
assert payload["data"]["inputs"] is None
assert payload["data"]["outputs"] is None
assert payload["data"]["files"] == []
def test_node_retry_to_ignore_detail_dict(self):
data = NodeRetryStreamResponse.Data(
id="exec-1",
node_id="node-1",
node_type="answer",
title="Answer",
index=1,
predecessor_node_id=None,
inputs={"foo": "bar"},
process_data={"step": 1},
outputs={"answer": "ok"},
status=WorkflowNodeExecutionStatus.RETRY,
elapsed_time=0.1,
created_at=1,
finished_at=2,
retry_index=2,
)
response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data)
payload = response.to_ignore_detail_dict()
assert payload["event"] == StreamEvent.NODE_RETRY.value
assert payload["data"]["retry_index"] == 2
assert payload["data"]["outputs"] is None

View File

@ -0,0 +1,163 @@
import logging
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
class TestAnnotationReplyFeature:
def test_query_returns_none_when_setting_missing(self):
feature = AnnotationReplyFeature()
with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db:
mock_db.session.scalar.return_value = None
result = feature.query(
app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
message=SimpleNamespace(id="msg-1"),
query="hi",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
assert result is None
def test_query_returns_none_when_binding_missing(self):
feature = AnnotationReplyFeature()
annotation_setting = SimpleNamespace(collection_binding_detail=None)
with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db:
mock_db.session.scalar.return_value = annotation_setting
result = feature.query(
app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
message=SimpleNamespace(id="msg-1"),
query="hi",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
assert result is None
def test_query_returns_annotation_and_records_history_for_api(self):
feature = AnnotationReplyFeature()
annotation_setting = SimpleNamespace(
score_threshold=None,
collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
)
dataset_binding = SimpleNamespace(id="binding-1")
annotation = SimpleNamespace(
id="ann-1",
question_text="question",
content="content",
account_id="acct-1",
account=SimpleNamespace(name="Alice"),
)
document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8})
vector_instance = Mock()
vector_instance.search_by_vector.return_value = [document]
with (
patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
patch(
"core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
) as mock_binding_service,
patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
patch(
"core.app.features.annotation_reply.annotation_reply.AppAnnotationService"
) as mock_annotation_service,
):
mock_db.session.scalar.return_value = annotation_setting
mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding
mock_vector.return_value = vector_instance
mock_annotation_service.get_annotation_by_id.return_value = annotation
result = feature.query(
app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
message=SimpleNamespace(id="msg-1"),
query="hi",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
assert result == annotation
mock_annotation_service.add_annotation_history.assert_called_once()
_, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0]
assert from_source == "api"
assert score == 0.8
def test_query_returns_annotation_and_records_history_for_console(self):
feature = AnnotationReplyFeature()
annotation_setting = SimpleNamespace(
score_threshold=0.5,
collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
)
dataset_binding = SimpleNamespace(id="binding-1")
annotation = SimpleNamespace(
id="ann-1",
question_text="question",
content="content",
account_id="acct-1",
account=None,
)
document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6})
vector_instance = Mock()
vector_instance.search_by_vector.return_value = [document]
with (
patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
patch(
"core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
) as mock_binding_service,
patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
patch(
"core.app.features.annotation_reply.annotation_reply.AppAnnotationService"
) as mock_annotation_service,
):
mock_db.session.scalar.return_value = annotation_setting
mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding
mock_vector.return_value = vector_instance
mock_annotation_service.get_annotation_by_id.return_value = annotation
result = feature.query(
app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
message=SimpleNamespace(id="msg-1"),
query="hi",
user_id="user-1",
invoke_from=InvokeFrom.EXPLORE,
)
assert result == annotation
_, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0]
assert from_source == "console"
def test_query_logs_and_returns_none_on_exception(self, caplog):
feature = AnnotationReplyFeature()
annotation_setting = SimpleNamespace(
score_threshold=None,
collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"),
)
with (
patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db,
patch(
"core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService"
) as mock_binding_service,
patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector,
):
mock_db.session.scalar.return_value = annotation_setting
mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1")
mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom")
with caplog.at_level(logging.WARNING):
result = feature.query(
app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
message=SimpleNamespace(id="msg-1"),
query="hi",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
assert result is None
assert "Query annotation failed" in caplog.text

View File

@ -0,0 +1,30 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
class TestHostingModerationFeature:
def test_check_aggregates_text_and_calls_moderation(self):
application_generate_entity = Mock()
application_generate_entity.model_conf = {"model": "mock"}
application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1")
prompt_messages = [
SimpleNamespace(content="hello"),
SimpleNamespace(content=123),
SimpleNamespace(content="world"),
]
with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check:
mock_check.return_value = True
feature = HostingModerationFeature()
result = feature.check(application_generate_entity, prompt_messages)
assert result is True
mock_check.assert_called_once_with(
tenant_id="tenant-1",
model_config=application_generate_entity.model_conf,
text="hello\nworld\n",
)

View File

@ -0,0 +1,19 @@
from core.app.layers.suspend_layer import SuspendLayer
from dify_graph.graph_events.graph import GraphRunPausedEvent
class TestSuspendLayer:
def test_on_event_accepts_paused_event(self):
layer = SuspendLayer()
assert layer.is_paused() is False
layer.on_graph_start()
assert layer.is_paused() is False
layer.on_event(GraphRunPausedEvent())
assert layer.is_paused() is True
def test_on_event_ignores_other_events(self):
layer = SuspendLayer()
layer.on_graph_start()
initial_state = layer.is_paused()
layer.on_event(object())
assert layer.is_paused() is initial_state

View File

@ -0,0 +1,98 @@
from unittest.mock import Mock, patch
from core.app.layers.timeslice_layer import TimeSliceLayer
from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import SchedulerCommand
class TestTimeSliceLayer:
def test_init_starts_scheduler_when_not_running(self):
scheduler = Mock()
scheduler.running = False
with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
_ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock()))
scheduler.start.assert_called_once()
def test_on_graph_start_adds_job_for_time_slice(self):
scheduler = Mock()
scheduler.running = True
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=3,
)
cfs_plan_scheduler = Mock(plan=plan)
with (
patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler),
patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid,
):
mock_uuid.return_value.hex = "job-1"
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer.on_graph_start()
assert layer.schedule_id == "job-1"
scheduler.add_job.assert_called_once()
def test_on_graph_end_removes_job(self):
scheduler = Mock()
scheduler.running = True
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=3,
)
cfs_plan_scheduler = Mock(plan=plan)
with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer.schedule_id = "job-1"
layer.on_graph_end(None)
scheduler.remove_job.assert_called_once_with("job-1")
def test_checker_job_removes_when_stopped(self):
scheduler = Mock()
scheduler.running = True
cfs_plan_scheduler = Mock(plan=Mock())
with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer.stopped = True
layer._checker_job("job-1")
scheduler.remove_job.assert_called_once_with("job-1")
def test_checker_job_handles_resource_limit_without_command_channel(self):
scheduler = Mock()
scheduler.running = True
cfs_plan_scheduler = Mock(plan=Mock())
cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED
with (
patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler),
patch("core.app.layers.timeslice_layer.logger") as mock_logger,
):
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer._checker_job("job-1")
scheduler.remove_job.assert_called_once_with("job-1")
mock_logger.exception.assert_called_once()
def test_checker_job_sends_pause_command(self):
scheduler = Mock()
scheduler.running = True
cfs_plan_scheduler = Mock(plan=Mock())
cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED
with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler):
layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler)
layer.command_channel = Mock()
layer._checker_job("job-1")
scheduler.remove_job.assert_called_once_with("job-1")
layer.command_channel.send_command.assert_called_once()
sent_command = layer.command_channel.send_command.call_args[0][0]
assert isinstance(sent_command, GraphEngineCommand)
assert sent_command.command_type == CommandType.PAUSE

View File

@ -0,0 +1,106 @@
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.app.layers.trigger_post_layer import TriggerPostLayer
from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent
from models.enums import WorkflowTriggerStatus
class TestTriggerPostLayer:
def test_on_event_updates_trigger_log(self):
trigger_log = SimpleNamespace(
status=None,
workflow_run_id=None,
outputs=None,
elapsed_time=None,
total_tokens=None,
finished_at=None,
)
runtime_state = SimpleNamespace(
outputs={"answer": "ok"},
system_variable=SimpleNamespace(workflow_execution_id="run-1"),
total_tokens=12,
)
with (
patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory,
patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls,
patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime,
):
mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC)
session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = session
repo = Mock()
repo.get_by_id.return_value = trigger_log
mock_repo_cls.return_value = repo
layer = TriggerPostLayer(
cfs_plan_scheduler_entity=Mock(),
start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10),
trigger_log_id="log-1",
)
layer.initialize(runtime_state, Mock())
layer.on_event(GraphRunSucceededEvent())
assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED
assert trigger_log.workflow_run_id == "run-1"
assert trigger_log.outputs is not None
assert trigger_log.elapsed_time is not None
assert trigger_log.total_tokens == 12
assert trigger_log.finished_at is not None
repo.update.assert_called_once_with(trigger_log)
session.commit.assert_called_once()
def test_on_event_handles_missing_trigger_log(self):
runtime_state = SimpleNamespace(
outputs={},
system_variable=SimpleNamespace(workflow_execution_id="run-1"),
total_tokens=0,
)
with (
patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory,
patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls,
patch("core.app.layers.trigger_post_layer.logger") as mock_logger,
):
session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = session
repo = Mock()
repo.get_by_id.return_value = None
mock_repo_cls.return_value = repo
layer = TriggerPostLayer(
cfs_plan_scheduler_entity=Mock(),
start_time=datetime(2026, 2, 20, tzinfo=UTC),
trigger_log_id="missing",
)
layer.initialize(runtime_state, Mock())
layer.on_event(GraphRunFailedEvent(error="boom"))
mock_logger.exception.assert_called_once()
session.commit.assert_not_called()
def test_on_event_ignores_non_status_events(self):
runtime_state = SimpleNamespace(
outputs={},
system_variable=SimpleNamespace(workflow_execution_id="run-1"),
total_tokens=0,
)
with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory:
layer = TriggerPostLayer(
cfs_plan_scheduler_entity=Mock(),
start_time=datetime(2026, 2, 20, tzinfo=UTC),
trigger_log_id="log-1",
)
layer.initialize(runtime_state, Mock())
layer.on_event(Mock())
mock_session_factory.create_session.assert_not_called()

View File

@ -0,0 +1,91 @@
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from core.app.entities.queue_entities import QueueErrorEvent
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.errors.error import QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from models.enums import MessageStatus
class TestBasedGenerateTaskPipeline:
@pytest.fixture
def pipeline(self):
app_config = SimpleNamespace(
tenant_id="tenant-1",
app_id="app-1",
sensitive_word_avoidance=None,
)
app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config)
return BasedGenerateTaskPipeline(
application_generate_entity=app_generate_entity,
queue_manager=Mock(),
stream=True,
)
def test_error_to_desc_quota_exceeded(self, pipeline):
message = pipeline._error_to_desc(QuotaExceededError())
assert "quota" in message.lower()
def test_handle_error_wraps_invoke_authorization(self, pipeline):
event = QueueErrorEvent(error=InvokeAuthorizationError())
err = pipeline.handle_error(event=event)
assert isinstance(err, InvokeAuthorizationError)
assert str(err) == "Incorrect API key provided"
def test_handle_error_preserves_invoke_error(self, pipeline):
event = QueueErrorEvent(error=InvokeError("bad"))
err = pipeline.handle_error(event=event)
assert err is event.error
def test_handle_error_updates_message_when_found(self, pipeline):
event = QueueErrorEvent(error=ValueError("oops"))
message = SimpleNamespace(status=MessageStatus.NORMAL, error=None)
session = Mock()
session.scalar.return_value = message
err = pipeline.handle_error(event=event, session=session, message_id="msg-1")
assert err is event.error
assert message.status == MessageStatus.ERROR
assert message.error == "oops"
def test_handle_error_returns_err_when_message_missing(self, pipeline):
event = QueueErrorEvent(error=ValueError("oops"))
session = Mock()
session.scalar.return_value = None
err = pipeline.handle_error(event=event, session=session, message_id="msg-1")
assert err is event.error
def test_error_to_stream_response_and_ping(self, pipeline):
error_response = pipeline.error_to_stream_response(ValueError("boom"))
ping_response = pipeline.ping_stream_response()
assert error_response.task_id == "task-1"
assert ping_response.task_id == "task-1"
def test_handle_output_moderation_when_flagged(self, pipeline):
handler = Mock()
handler.moderation_completion.return_value = ("filtered", True)
pipeline.output_moderation_handler = handler
result = pipeline.handle_output_moderation_when_task_finished("raw")
assert result == "filtered"
handler.stop_thread.assert_called_once()
assert pipeline.output_moderation_handler is None
def test_handle_output_moderation_when_not_flagged(self, pipeline):
handler = Mock()
handler.moderation_completion.return_value = ("safe", False)
pipeline.output_moderation_handler = handler
result = pipeline.handle_output_moderation_when_task_finished("raw")
assert result is None
handler.stop_thread.assert_called_once()
assert pipeline.output_moderation_handler is None

View File

@ -0,0 +1,11 @@
from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError
class TestTaskPipelineExceptions:
def test_record_not_found_error_message(self):
err = RecordNotFoundError("Message", "msg-1")
assert str(err) == "Message with id msg-1 not found"
def test_workflow_run_not_found_error_message(self):
err = WorkflowRunNotFoundError("run-1")
assert str(err) == "WorkflowRun with id run-1 not found"

View File

@ -1,12 +1,16 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import current_app
from flask import Flask, current_app
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from models.model import AppMode
class TestMessageCycleManagerOptimization:
@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization:
assert result == StreamEvent.MESSAGE
mock_session.scalar.assert_called_once()
def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager):
"""Return MESSAGE_FILE directly from in-memory cache without opening a DB session."""
message_cycle_manager._message_has_file.add("cached-message")
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
result = message_cycle_manager.get_message_event_type("cached-message")
assert result == StreamEvent.MESSAGE_FILE
mock_session_factory.create_session.assert_not_called()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization:
assert chunk2_response.event == StreamEvent.MESSAGE
assert chunk1_response.answer == "Chunk 1"
assert chunk2_response.answer == "Chunk 2"
def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager):
"""Return None when completion entities are used for conversation naming.
Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity.
Returns: None, indicating no name generation for completion apps.
Side effects: None expected.
"""
class DummyCompletion:
pass
with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion):
message_cycle_manager._application_generate_entity = DummyCompletion()
result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi")
assert result is None
def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager):
"""Spawn background generation thread for the first chat message."""
message_cycle_manager._application_generate_entity.is_new_conversation = True
message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True}
flask_app = object()
class DummyTimer:
def __init__(self, interval, function, args=None, kwargs=None):
self.interval = interval
self.function = function
self.args = args or []
self.kwargs = kwargs
self.daemon = False
self.started = False
def start(self):
self.started = True
with (
patch(
"core.app.task_pipeline.message_cycle_manager.current_app",
new=SimpleNamespace(_get_current_object=lambda: flask_app),
),
patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer),
):
thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello")
assert isinstance(thread, DummyTimer)
assert thread.interval == 1
assert thread.function == message_cycle_manager._generate_conversation_name_worker
assert thread.started is True
assert thread.daemon is True
assert thread.kwargs["flask_app"] is flask_app
assert thread.kwargs["conversation_id"] == "conv-1"
assert thread.kwargs["query"] == "hello"
assert message_cycle_manager._application_generate_entity.is_new_conversation is False
def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager):
"""Skip thread creation when auto naming is disabled but still mark conversation as not new."""
message_cycle_manager._application_generate_entity.is_new_conversation = True
message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False}
with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer:
result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello")
assert result is None
assert message_cycle_manager._application_generate_entity.is_new_conversation is False
mock_timer.assert_not_called()
def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager):
"""Return early when the conversation cannot be found."""
flask_app = Flask(__name__)
db_session = Mock()
db_session.scalar.return_value = None
with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db:
mock_db.session = db_session
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello")
db_session.commit.assert_not_called()
db_session.close.assert_not_called()
def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager):
"""Return early when non-completion conversation has no app relation."""
flask_app = Flask(__name__)
conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id")
db_session = Mock()
db_session.scalar.return_value = conversation
with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db:
mock_db.session = db_session
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")
db_session.commit.assert_not_called()
db_session.close.assert_not_called()
def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager):
"""Use cached conversation name when present and avoid LLM call."""
flask_app = Flask(__name__)
conversation = SimpleNamespace(
mode=AppMode.CHAT,
app=SimpleNamespace(tenant_id="tenant-1"),
app_id="app-id",
name="",
)
db_session = Mock()
db_session.scalar.return_value = conversation
with (
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
):
mock_db.session = db_session
mock_redis.get.return_value = b"cached-title"
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")
assert conversation.name == "cached-title"
db_session.commit.assert_called_once()
db_session.close.assert_called_once()
mock_llm_generator.generate_conversation_name.assert_not_called()
mock_redis.setex.assert_not_called()
def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager):
"""Generate conversation name and write it to redis cache on cache miss."""
flask_app = Flask(__name__)
conversation = SimpleNamespace(
mode=AppMode.CHAT,
app=SimpleNamespace(tenant_id="tenant-1"),
app_id="app-id",
name="",
)
db_session = Mock()
db_session.scalar.return_value = conversation
with (
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
):
mock_db.session = db_session
mock_redis.get.return_value = None
mock_llm_generator.generate_conversation_name.return_value = "generated-title"
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello")
assert conversation.name == "generated-title"
db_session.commit.assert_called_once()
db_session.close.assert_called_once()
mock_redis.setex.assert_called_once()
def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager):
"""Fallback to truncated query when LLM generation fails."""
flask_app = Flask(__name__)
conversation = SimpleNamespace(
mode=AppMode.CHAT,
app=SimpleNamespace(tenant_id="tenant-1"),
app_id="app-id",
name="",
)
db_session = Mock()
db_session.scalar.return_value = conversation
long_query = "q" * 60
with (
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis,
patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator,
patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config,
patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger,
):
mock_db.session = db_session
mock_redis.get.return_value = None
mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed")
mock_dify_config.DEBUG = True
message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query)
assert conversation.name == (long_query[:47] + "...")
db_session.commit.assert_called_once()
db_session.close.assert_called_once()
mock_logger.exception.assert_called_once()
def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager):
"""Populate task metadata from annotation reply events.
Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService.
Returns: The fetched annotation object.
Side effects: Updates metadata.annotation_reply with id and account name.
"""
message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())
annotation = SimpleNamespace(
id="ann-1",
account_id="acct-1",
account=SimpleNamespace(name="Alice"),
)
with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service:
mock_service.get_annotation_by_id.return_value = annotation
result = message_cycle_manager.handle_annotation_reply(
QueueAnnotationReplyEvent(message_annotation_id="ann-1")
)
assert result == annotation
assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1"
assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice"
def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager):
"""Return None and keep metadata unchanged when annotation is not found."""
message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())
with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service:
mock_service.get_annotation_by_id.return_value = None
result = message_cycle_manager.handle_annotation_reply(
QueueAnnotationReplyEvent(message_annotation_id="missing")
)
assert result is None
assert message_cycle_manager._task_state.metadata.annotation_reply is None
def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager):
"""Merge retriever resources, deduplicate, and preserve ordering positions.
Args: message_cycle_manager with show_retrieve_source enabled and existing metadata.
Returns: None.
Side effects: Updates metadata.retriever_resources with unique items and positions.
"""
message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(
additional_features=SimpleNamespace(show_retrieve_source=True)
)
existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")
message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing]))
duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")
new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2")
event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource])
message_cycle_manager.handle_retriever_resources(event)
assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2
assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1
assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2
def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager):
"""Build a stream response with a signed tool file URL.
Args: message_cycle_manager with mocked Session/db and sign_tool_file.
Returns: MessageStreamResponse with signed url and belongs_to normalized to user.
Side effects: Calls sign_tool_file for tool file ids.
"""
message_cycle_manager._application_generate_entity.task_id = "task-1"
message_file = SimpleNamespace(
id="file-1",
type="image",
belongs_to=None,
url="tool://file.verylongextension",
message_id="msg-1",
)
session = Mock()
session.scalar.return_value = message_file
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign,
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
):
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = session
mock_sign.return_value = "signed-url"
response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1"))
assert response.url == "signed-url"
assert response.belongs_to == "user"
mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin")
def test_handle_retriever_resources_requires_features(self, message_cycle_manager):
"""Raise when retriever resources are handled without feature config.
Args: message_cycle_manager with additional_features unset and empty metadata.
Raises: ValueError when show_retrieve_source configuration is missing.
"""
message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None)
message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata())
with pytest.raises(ValueError):
message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[]))
def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager):
"""Ignore null resource entries while preserving valid resources."""
message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(
additional_features=SimpleNamespace(show_retrieve_source=True)
)
message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[]))
resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1")
message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource]))
assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1
assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1
def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager):
"""Use original URL when message file URL is already HTTP."""
message_cycle_manager._application_generate_entity.task_id = "task-http"
message_file = SimpleNamespace(
id="file-http",
type="image",
belongs_to="assistant",
url="http://example.com/pic.png",
message_id="msg-http",
)
session = Mock()
session.scalar.return_value = message_file
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
):
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = session
response = message_cycle_manager.message_file_to_stream_response(
SimpleNamespace(message_file_id="file-http")
)
assert response is not None
assert response.url == "http://example.com/pic.png"
assert "msg-http" in message_cycle_manager._message_has_file
def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager):
"""Default tool file extension to .bin when URL has no extension part."""
message_cycle_manager._application_generate_entity.task_id = "task-bin"
message_file = SimpleNamespace(
id="file-bin",
type="file",
belongs_to="assistant",
url="tool-file-id",
message_id="msg-bin",
)
session = Mock()
session.scalar.return_value = message_file
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign,
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
):
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = session
mock_sign.return_value = "signed-bin-url"
response = message_cycle_manager.message_file_to_stream_response(
SimpleNamespace(message_file_id="file-bin")
)
assert response is not None
assert response.url == "signed-bin-url"
mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin")
def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager):
"""Return None when message file lookup does not find a record."""
session = Mock()
session.scalar.return_value = None
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls,
patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db,
):
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = session
response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing"))
assert response is None
def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager):
"""Include the provided replacement reason in the stream payload."""
response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation")
assert response.answer == "replaced"
assert response.reason == "moderation"

View File

@ -0,0 +1,43 @@
from unittest.mock import patch
from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime
class TestDifyWorkflowFileRuntime:
def test_runtime_properties_and_helpers(self, monkeypatch):
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files")
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal")
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret")
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123)
monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url")
runtime = DifyWorkflowFileRuntime()
assert runtime.files_url == "http://files"
assert runtime.internal_files_url == "http://internal"
assert runtime.secret_key == "secret"
assert runtime.files_access_timeout == 123
assert runtime.multimodal_send_format == "url"
with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get:
mock_get.return_value = "response"
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)
with patch("core.app.workflow.file_runtime.storage.load") as mock_load:
mock_load.return_value = b"data"
assert runtime.storage_load("path", stream=True) == b"data"
mock_load.assert_called_once_with("path", stream=True)
with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign:
mock_sign.return_value = "signed"
assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed"
mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False)
def test_bind_runtime_registers_instance(self):
with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set:
bind_dify_workflow_file_runtime()
mock_set.assert_called_once()
runtime = mock_set.call_args[0][0]
assert isinstance(runtime, DifyWorkflowFileRuntime)

View File

@ -0,0 +1,161 @@
from types import SimpleNamespace
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
self.id = id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self.kwargs = kwargs
class DummyCodeNode(DummyNode):
@classmethod
def default_code_providers(cls):
return ()
class DummyTemplateTransformNode(DummyNode):
pass
class DummyHttpRequestNode(DummyNode):
pass
class DummyKnowledgeRetrievalNode(DummyNode):
pass
class DummyDocumentExtractorNode(DummyNode):
pass
class TestDifyNodeFactory:
@staticmethod
def _stub_node_resolution(monkeypatch, node_class):
monkeypatch.setattr(
"core.workflow.node_factory.resolve_workflow_node_class",
lambda **_kwargs: node_class,
)
def _factory(self, monkeypatch):
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2)
monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2)
monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100)
monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u")
monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key")
run_context = build_dify_run_context(
tenant_id="tenant",
app_id="app",
user_id="user",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
return DifyNodeFactory(
graph_init_params=SimpleNamespace(run_context=run_context),
graph_runtime_state=SimpleNamespace(),
)
def test_create_node_unknown_type(self, monkeypatch):
factory = self._factory(monkeypatch)
with pytest.raises(ValueError):
factory.create_node({"id": "node-1", "data": {"type": "unknown"}})
def test_create_node_missing_mapping(self, monkeypatch):
factory = self._factory(monkeypatch)
monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {})
with pytest.raises(ValueError):
factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}})
def test_create_node_missing_latest_class(self, monkeypatch):
factory = self._factory(monkeypatch)
monkeypatch.setattr(
"core.workflow.node_factory.get_node_type_classes_mapping",
lambda: {BuiltinNodeTypes.START: {"1": None}},
)
monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest")
with pytest.raises(ValueError):
factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}})
def test_create_node_selects_versioned_class(self, monkeypatch):
factory = self._factory(monkeypatch)
selected_versions: list[tuple[str, str]] = []
class DummyNodeV2(DummyNode):
pass
def _get_mapping():
selected_versions.append(("snapshot", "called"))
return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}}
monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}})
assert isinstance(node, DummyNodeV2)
assert node.id == "node-1"
assert selected_versions == [("snapshot", "called")]
def test_create_node_code_branch(self, monkeypatch):
factory = self._factory(monkeypatch)
self._stub_node_resolution(monkeypatch, DummyCodeNode)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}})
assert isinstance(node, DummyCodeNode)
assert node.id == "node-1"
def test_create_node_template_transform_branch(self, monkeypatch):
factory = self._factory(monkeypatch)
self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}})
assert isinstance(node, DummyTemplateTransformNode)
assert "template_renderer" in node.kwargs
def test_create_node_http_request_branch(self, monkeypatch):
factory = self._factory(monkeypatch)
self._stub_node_resolution(monkeypatch, DummyHttpRequestNode)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}})
assert isinstance(node, DummyHttpRequestNode)
assert "http_request_config" in node.kwargs
def test_create_node_knowledge_retrieval_branch(self, monkeypatch):
factory = self._factory(monkeypatch)
self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}})
assert isinstance(node, DummyKnowledgeRetrievalNode)
assert node.kwargs == {}
def test_create_node_document_extractor_branch(self, monkeypatch):
factory = self._factory(monkeypatch)
self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode)
node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}})
assert isinstance(node, DummyDocumentExtractorNode)
assert "unstructured_api_config" in node.kwargs

View File

@ -0,0 +1,209 @@
from __future__ import annotations
from types import SimpleNamespace
from core.app.workflow.layers.observability import ObservabilityLayer
from dify_graph.enums import BuiltinNodeTypes
class TestObservabilityLayerExtras:
def test_init_tracer_enabled_sets_tracer(self, monkeypatch):
tracer = object()
monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)
monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer)
layer = ObservabilityLayer()
assert layer._is_disabled is False
assert layer._tracer is tracer
def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog):
monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True)
monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)
def _raise(*_args, **_kwargs):
raise RuntimeError("tracer init failed")
monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise)
layer = ObservabilityLayer()
assert layer._is_disabled is True
assert layer._tracer is None
assert "Failed to get OpenTelemetry tracer" in caplog.text
def test_init_tracer_disables_when_otel_disabled(self, monkeypatch):
monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False)
monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False)
layer = ObservabilityLayer()
assert layer._is_disabled is True
def test_get_parser_uses_registry_when_node_type_matches(self):
layer = ObservabilityLayer()
parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL))
assert parser is layer._parsers[BuiltinNodeTypes.TOOL]
def test_get_parser_defaults_when_node_type_missing(self):
layer = ObservabilityLayer()
parser = layer._get_parser(SimpleNamespace(node_type=None))
assert parser is layer._default_parser
def test_on_graph_start_clears_contexts(self):
layer = ObservabilityLayer()
layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")
layer.on_graph_start()
assert layer._node_contexts == {}
def test_on_event_is_noop(self):
layer = ObservabilityLayer()
layer.on_event(object())
def test_on_graph_end_clears_unfinished_contexts(self, caplog):
layer = ObservabilityLayer()
layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")
layer.on_graph_end(error=None)
assert layer._node_contexts == {}
assert "node spans were not properly ended" in caplog.text
def test_on_node_run_start_skips_without_execution_id(self):
layer = ObservabilityLayer()
layer._is_disabled = False
layer._tracer = None
layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node"))
assert layer._node_contexts == {}
def test_on_node_run_start_skips_when_disabled(self):
layer = ObservabilityLayer()
layer._is_disabled = True
layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object())
layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node"))
assert layer._node_contexts == {}
def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self):
layer = ObservabilityLayer()
layer._is_disabled = False
calls: list[str] = []
layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called"))
layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node"))
assert calls == []
def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog):
layer = ObservabilityLayer()
layer._is_disabled = False
def _raise(*_args, **_kwargs):
raise RuntimeError("start failed")
layer._tracer = SimpleNamespace(start_span=_raise)
layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node"))
assert "Failed to create OpenTelemetry span for node" in caplog.text
def test_on_node_run_end_without_context_noop(self):
layer = ObservabilityLayer()
layer._is_disabled = False
layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None)
assert layer._node_contexts == {}
def test_on_node_run_end_skips_when_disabled(self):
layer = ObservabilityLayer()
layer._is_disabled = True
layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token")
layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None)
assert "exec" in layer._node_contexts
def test_on_node_run_end_skips_without_execution_id(self):
layer = ObservabilityLayer()
layer._is_disabled = False
layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None)
assert layer._node_contexts == {}
def test_on_node_run_end_calls_span_end(self, monkeypatch):
layer = ObservabilityLayer()
layer._is_disabled = False
ended: list[str] = []
class _Parser:
def parse(self, **_kwargs):
return None
span = SimpleNamespace(end=lambda: ended.append("ended"))
layer._default_parser = _Parser()
layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token")
monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None)
node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)
layer.on_node_run_end(node, error=None)
assert ended == ["ended"]
assert "exec" not in layer._node_contexts
def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog):
layer = ObservabilityLayer()
layer._is_disabled = False
class _Parser:
def parse(self, **_kwargs):
return None
layer._default_parser = _Parser()
layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token")
def _raise(*_args, **_kwargs):
raise RuntimeError("detach failed")
monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise)
node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)
layer.on_node_run_end(node, error=None)
assert "Failed to detach OpenTelemetry token" in caplog.text
assert "exec" not in layer._node_contexts
def test_on_node_run_start_and_end_creates_span(self, monkeypatch):
layer = ObservabilityLayer()
layer._is_disabled = False
span = SimpleNamespace(end=lambda: None)
tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span)
monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object())
monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object())
monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token")
monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None)
layer._tracer = tracer
node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None)
layer.on_node_run_start(node)
assert "exec" in layer._node_contexts
layer.on_node_run_end(node, error=None)
assert "exec" not in layer._node_contexts

View File

@ -0,0 +1,499 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from dify_graph.entities.pause_reason import SchedulingPause
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from dify_graph.enums import (
BuiltinNodeTypes,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from dify_graph.graph_events.graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from dify_graph.graph_events.node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from dify_graph.node_events import NodeRunResult
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
from dify_graph.system_variable import SystemVariable
class _RepoRecorder:
def __init__(self) -> None:
self.saved: list[object] = []
self.saved_exec_data: list[object] = []
def save(self, entity):
self.saved.append(entity)
def save_execution_data(self, entity):
self.saved_exec_data.append(entity)
def _naive_utc_now() -> datetime:
return datetime.now(UTC).replace(tzinfo=None)
def _make_layer(
system_variable: SystemVariable | None = None,
*,
extras: dict | None = None,
trace_manager: object | None = None,
):
system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id")
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0)
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=SimpleNamespace(app_id="app", tenant_id="tenant"),
inputs={"foo": "bar"},
files=[],
user_id="user",
stream=False,
invoke_from=None,
trace_manager=None,
workflow_execution_id="run-id",
extras=extras or {},
call_depth=0,
)
workflow_info = PersistenceWorkflowInfo(
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
version="1",
graph_data={"nodes": [], "edges": []},
)
workflow_execution_repo = _RepoRecorder()
workflow_node_execution_repo = _RepoRecorder()
layer = WorkflowPersistenceLayer(
application_generate_entity=application_generate_entity,
workflow_info=workflow_info,
workflow_execution_repository=workflow_execution_repo,
workflow_node_execution_repository=workflow_node_execution_repo,
trace_manager=trace_manager,
)
layer.initialize(read_only_state, command_channel=None)
return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state
class TestWorkflowPersistenceLayer:
def test_on_graph_start_resets_state(self):
layer, _, _, _ = _make_layer()
layer._workflow_execution = object()
layer._node_execution_cache["cached"] = object()
layer._node_snapshots["cached"] = object()
layer._node_sequence = 9
layer.on_graph_start()
assert layer._workflow_execution is None
assert layer._node_execution_cache == {}
assert layer._node_snapshots == {}
assert layer._node_sequence == 0
def test_get_execution_id_requires_system_variable(self):
system_variable = SystemVariable(workflow_execution_id=None)
layer, _, _, _ = _make_layer(system_variable)
with pytest.raises(ValueError, match="workflow_execution_id must be provided"):
layer._get_execution_id()
def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch):
layer, _, _, _ = _make_layer()
monkeypatch.setattr(
"core.workflow.workflow_entry.WorkflowEntry.handle_special_values",
lambda inputs: inputs,
)
inputs = layer._prepare_workflow_inputs()
assert "sys.conversation_id" not in inputs
assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id"
def test_fail_running_node_executions_marks_failed(self):
layer, _, node_repo, _ = _make_layer()
execution = WorkflowNodeExecution(
id="exec-id",
workflow_id="workflow-id",
workflow_execution_id="run-id",
index=1,
node_id="node",
node_type=BuiltinNodeTypes.START,
title="Start",
created_at=_naive_utc_now(),
)
layer._node_execution_cache[execution.id] = execution
layer._fail_running_node_executions(error_message="boom")
assert execution.status == WorkflowNodeExecutionStatus.FAILED
assert node_repo.saved
def test_handle_graph_run_started_saves_execution(self):
layer, exec_repo, _, _ = _make_layer()
layer._handle_graph_run_started()
assert exec_repo.saved
def test_handle_graph_run_succeeded_updates_execution(self):
layer, exec_repo, _, runtime_state = _make_layer()
layer._handle_graph_run_started()
runtime_state.total_tokens = 3
runtime_state.node_run_steps = 2
runtime_state.outputs = {"out": "v"}
layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
saved = exec_repo.saved[-1]
assert saved.status == WorkflowExecutionStatus.SUCCEEDED
assert saved.total_tokens == 3
assert saved.total_steps == 2
def test_handle_graph_run_partial_succeeded_updates_execution(self):
layer, exec_repo, _, runtime_state = _make_layer()
layer._handle_graph_run_started()
runtime_state.total_tokens = 5
runtime_state.node_run_steps = 4
runtime_state._graph_execution = SimpleNamespace(exceptions_count=2)
layer._handle_graph_run_partial_succeeded(
GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2)
)
saved = exec_repo.saved[-1]
assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
assert saved.exceptions_count == 2
assert saved.total_tokens == 5
def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self):
trace_tasks: list[object] = []
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager)
layer._handle_graph_run_started()
running = WorkflowNodeExecution(
id="node-exec",
workflow_id="workflow-id",
workflow_execution_id="run-id",
index=1,
node_id="node",
node_type=BuiltinNodeTypes.START,
title="Start",
created_at=_naive_utc_now(),
)
layer._node_execution_cache[running.id] = running
layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1))
assert node_repo.saved
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
assert trace_tasks
def test_handle_graph_run_aborted_sets_status(self):
layer, exec_repo, _, _ = _make_layer()
layer._handle_graph_run_started()
layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={}))
saved = exec_repo.saved[-1]
assert saved.status == WorkflowExecutionStatus.STOPPED
assert saved.error_message
def test_handle_graph_run_paused_updates_outputs(self):
layer, exec_repo, _, runtime_state = _make_layer()
layer._handle_graph_run_started()
runtime_state.total_tokens = 7
runtime_state.node_run_steps = 5
layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True}))
saved = exec_repo.saved[-1]
assert saved.status == WorkflowExecutionStatus.PAUSED
assert saved.outputs == {"pause": True}
assert saved.finished_at is None
def test_handle_node_started_and_retry(self):
layer, _, node_repo, _ = _make_layer()
layer._handle_graph_run_started()
start_event = NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=_naive_utc_now(),
predecessor_node_id="prev",
in_iteration_id="iter",
in_loop_id="loop",
)
layer._handle_node_started(start_event)
assert node_repo.saved
assert "exec" in layer._node_execution_cache
assert layer._node_snapshots["exec"].node_id == "node"
retry_event = NodeRunRetryEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=_naive_utc_now(),
error="retry",
retry_index=1,
)
layer._handle_node_retry(retry_event)
assert node_repo.saved_exec_data
def test_handle_node_result_events_update_execution(self):
layer, _, node_repo, _ = _make_layer()
layer._handle_graph_run_started()
start_event = NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
node_title="LLM",
start_at=_naive_utc_now(),
)
layer._handle_node_started(start_event)
result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={})
success_event = NodeRunSucceededEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
start_at=_naive_utc_now(),
node_run_result=result,
)
layer._handle_node_succeeded(success_event)
failed_event = NodeRunFailedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
start_at=_naive_utc_now(),
error="boom",
node_run_result=result,
)
layer._handle_node_failed(failed_event)
exception_event = NodeRunExceptionEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
start_at=_naive_utc_now(),
error="err",
node_run_result=result,
)
layer._handle_node_exception(exception_event)
assert node_repo.saved_exec_data
def test_handle_node_pause_requested_skips_outputs(self):
layer, _, _, _ = _make_layer()
layer._handle_graph_run_started()
start_event = NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
node_title="LLM",
start_at=_naive_utc_now(),
)
layer._handle_node_started(start_event)
domain_execution = layer._node_execution_cache["exec"]
domain_execution.inputs = {"old": True}
result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={})
pause_event = NodeRunPauseRequestedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
reason=SchedulingPause(message="pause"),
node_run_result=result,
)
layer._handle_node_pause_requested(pause_event)
assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED
assert domain_execution.inputs == {"old": True}
def test_get_node_execution_raises_for_missing(self):
layer, _, _, _ = _make_layer()
with pytest.raises(ValueError, match="Node execution not found"):
layer._get_node_execution("missing")
def test_get_workflow_execution_raises_when_uninitialized(self):
layer, _, _, _ = _make_layer()
with pytest.raises(ValueError, match="workflow execution not initialized"):
layer._get_workflow_execution()
def test_next_node_sequence_increments(self):
layer, _, _, _ = _make_layer()
assert layer._next_node_sequence() == 1
assert layer._next_node_sequence() == 2
def test_on_graph_end_is_noop(self):
layer, _, _, _ = _make_layer()
assert layer.on_graph_end(error=None) is None
def test_on_event_dispatches_to_all_known_handlers(self):
layer, _, _, _ = _make_layer()
called: list[str] = []
def _record(name: str):
def _handler(*_args, **_kwargs):
called.append(name)
return _handler
layer._handle_graph_run_started = _record("started")
layer._handle_graph_run_succeeded = _record("succeeded")
layer._handle_graph_run_partial_succeeded = _record("partial")
layer._handle_graph_run_failed = _record("failed")
layer._handle_graph_run_aborted = _record("aborted")
layer._handle_graph_run_paused = _record("paused")
layer._handle_node_started = _record("node_started")
layer._handle_node_retry = _record("node_retry")
layer._handle_node_succeeded = _record("node_succeeded")
layer._handle_node_failed = _record("node_failed")
layer._handle_node_exception = _record("node_exception")
layer._handle_node_pause_requested = _record("node_paused")
node_result = NodeRunResult()
now = _naive_utc_now()
events = [
GraphRunStartedEvent(),
GraphRunSucceededEvent(outputs={"ok": True}),
GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1),
GraphRunFailedEvent(error="boom", exceptions_count=1),
GraphRunAbortedEvent(reason="stop", outputs={"x": 1}),
GraphRunPausedEvent(outputs={"pause": True}),
NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=now,
),
NodeRunRetryEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=now,
error="retry",
retry_index=1,
),
NodeRunSucceededEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
start_at=now,
node_run_result=node_result,
),
NodeRunFailedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
start_at=now,
error="failed",
node_run_result=node_result,
),
NodeRunExceptionEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
start_at=now,
error="error",
node_run_result=node_result,
),
NodeRunPauseRequestedEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
reason=SchedulingPause(message="pause"),
node_run_result=node_result,
),
]
expected_order = [
"started",
"succeeded",
"partial",
"failed",
"aborted",
"paused",
"node_started",
"node_retry",
"node_succeeded",
"node_failed",
"node_exception",
"node_paused",
]
for event in events:
layer.on_event(event)
assert called == expected_order
def test_on_event_dispatches_retry_before_started_for_retry_event(self):
layer, _, _, _ = _make_layer()
called: list[str] = []
def _record(name: str):
def _handler(*_args, **_kwargs):
called.append(name)
return _handler
layer._handle_node_started = _record("node_started")
layer._handle_node_retry = _record("node_retry")
layer.on_event(
NodeRunRetryEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.START,
node_title="Start",
start_at=_naive_utc_now(),
error="retry",
retry_index=1,
)
)
assert called == ["node_retry"]
def test_enqueue_trace_task_skips_when_disabled(self):
trace_tasks: list[object] = []
layer, exec_repo, _, _ = _make_layer()
layer._handle_graph_run_started()
layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
assert exec_repo.saved
assert not trace_tasks

View File

@ -110,7 +110,7 @@ def test_encrypt_tool_parameters():
assert encrypted["plain"] == "x"
def test_decrypt_tool_parameters_cache_hit_and_miss():
def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch):
manager = _build_manager()
with (
@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache():
mock_delete.assert_called_once()
def test_configuration_manager_decrypt_suppresses_errors():
def test_configuration_manager_decrypt_suppresses_errors(monkeypatch):
manager = _build_manager()
with (
patch.object(ToolParameterCache, "get", return_value=None),