mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 11:56:55 +08:00
feat(api): enhance workflow validation and structure checks
- Added a new validation class to ensure that trigger nodes do not coexist with UserInput (start) nodes in the workflow graph. - Implemented a method in WorkflowService to validate the graph structure before persisting workflows, leveraging the new validation logic. - Updated unit tests to cover the new validation scenarios and ensure proper error propagation.
This commit is contained in:
parent
7484a020e1
commit
aad31bb703
@ -43,6 +43,9 @@ class InvokeFrom(StrEnum):
|
|||||||
# the workflow (or chatflow) edit page.
|
# the workflow (or chatflow) edit page.
|
||||||
DEBUGGER = "debugger"
|
DEBUGGER = "debugger"
|
||||||
PUBLISHED = "published"
|
PUBLISHED = "published"
|
||||||
|
|
||||||
|
# VALIDATION indicates that this invocation is from validation.
|
||||||
|
VALIDATION = "validation"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str):
|
def value_of(cls, value: str):
|
||||||
|
|||||||
@ -114,9 +114,45 @@ class GraphValidator:
|
|||||||
raise GraphValidationError(issues)
|
raise GraphValidationError(issues)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _TriggerStartExclusivityValidator:
|
||||||
|
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
|
||||||
|
|
||||||
|
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||||
|
start_node_id: str | None = None
|
||||||
|
trigger_node_ids: list[str] = []
|
||||||
|
|
||||||
|
for node in graph.nodes.values():
|
||||||
|
node_type = getattr(node, "node_type", None)
|
||||||
|
if not isinstance(node_type, NodeType):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if node_type == NodeType.START:
|
||||||
|
start_node_id = node.id
|
||||||
|
elif node_type.is_trigger_node:
|
||||||
|
trigger_node_ids.append(node.id)
|
||||||
|
|
||||||
|
if start_node_id and trigger_node_ids:
|
||||||
|
trigger_list = ", ".join(trigger_node_ids)
|
||||||
|
return [
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.conflict_code,
|
||||||
|
message=(
|
||||||
|
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
|
||||||
|
),
|
||||||
|
node_id=start_node_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||||
_EdgeEndpointValidator(),
|
_EdgeEndpointValidator(),
|
||||||
_RootNodeValidator(),
|
_RootNodeValidator(),
|
||||||
|
_TriggerStartExclusivityValidator(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker
|
|||||||
from core.app.app_config.entities import VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.repositories import DifyCoreRepositoryFactory
|
from core.repositories import DifyCoreRepositoryFactory
|
||||||
from core.variables import Variable
|
from core.variables import Variable
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.entities import WorkflowNodeExecution
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution
|
||||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
|
from core.workflow.graph.graph import Graph
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
|
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.nodes.start.entities import StartNodeData
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
from core.workflow.runtime import VariablePool
|
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
@ -32,6 +34,7 @@ from extensions.ext_storage import storage
|
|||||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models import Account
|
from models import Account
|
||||||
|
from models.enums import UserFrom
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
|
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
|
||||||
@ -211,6 +214,9 @@ class WorkflowService:
|
|||||||
# validate features structure
|
# validate features structure
|
||||||
self.validate_features_structure(app_model=app_model, features=features)
|
self.validate_features_structure(app_model=app_model, features=features)
|
||||||
|
|
||||||
|
# validate graph structure
|
||||||
|
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph)
|
||||||
|
|
||||||
# create draft workflow if not found
|
# create draft workflow if not found
|
||||||
if not workflow:
|
if not workflow:
|
||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
@ -267,6 +273,9 @@ class WorkflowService:
|
|||||||
if FeatureService.get_system_features().plugin_manager.enabled:
|
if FeatureService.get_system_features().plugin_manager.enabled:
|
||||||
self._validate_workflow_credentials(draft_workflow)
|
self._validate_workflow_credentials(draft_workflow)
|
||||||
|
|
||||||
|
# validate graph structure
|
||||||
|
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict)
|
||||||
|
|
||||||
# create new workflow
|
# create new workflow
|
||||||
workflow = Workflow.new(
|
workflow = Workflow.new(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
@ -896,6 +905,36 @@ class WorkflowService:
|
|||||||
|
|
||||||
return new_app
|
return new_app
|
||||||
|
|
||||||
|
def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate workflow graph structure by instantiating the Graph object.
|
||||||
|
|
||||||
|
This leverages the built-in graph validators (including trigger/UserInput exclusivity)
|
||||||
|
and raises any structural errors before persisting the workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Graph.init(
|
||||||
|
graph_config=graph,
|
||||||
|
# TODO(Mairuis): Add root node id
|
||||||
|
root_node_id=None,
|
||||||
|
node_factory=DifyNodeFactory(
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
workflow_id=app_model.workflow_id,
|
||||||
|
graph_config=graph,
|
||||||
|
user_id=user_id,
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.VALIDATION,
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(),
|
||||||
|
start_at=time.perf_counter(),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def validate_features_structure(self, app_model: App, features: dict):
|
def validate_features_structure(self, app_model: App, features: dict):
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||||
return AdvancedChatAppConfigManager.config_validate(
|
return AdvancedChatAppConfigManager.config_validate(
|
||||||
|
|||||||
@ -64,6 +64,15 @@ class _TestNode(Node):
|
|||||||
)
|
)
|
||||||
self.data = dict(data)
|
self.data = dict(data)
|
||||||
|
|
||||||
|
node_type_value = data.get("type")
|
||||||
|
if isinstance(node_type_value, NodeType):
|
||||||
|
self.node_type = node_type_value
|
||||||
|
elif isinstance(node_type_value, str):
|
||||||
|
try:
|
||||||
|
self.node_type = NodeType(node_type_value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -179,3 +188,22 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
|
|||||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||||
|
|
||||||
assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
|
assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_validation_blocks_start_and_trigger_coexistence(
|
||||||
|
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
|
||||||
|
) -> None:
|
||||||
|
node_factory, graph_config = graph_init_dependencies
|
||||||
|
graph_config["nodes"] = [
|
||||||
|
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
|
||||||
|
{
|
||||||
|
"id": "trigger",
|
||||||
|
"data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
graph_config["edges"] = []
|
||||||
|
|
||||||
|
with pytest.raises(GraphValidationError) as exc_info:
|
||||||
|
Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||||
|
|
||||||
|
assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues)
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from core.workflow.graph.validation import GraphValidationError, GraphValidationIssue
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
@ -161,3 +162,25 @@ class TestWorkflowService:
|
|||||||
assert workflows == []
|
assert workflows == []
|
||||||
assert has_more is False
|
assert has_more is False
|
||||||
mock_session.scalars.assert_called_once()
|
mock_session.scalars.assert_called_once()
|
||||||
|
|
||||||
|
def test_validate_graph_structure_invokes_graph_init(self, workflow_service, mock_app):
|
||||||
|
graph = {"nodes": [], "edges": []}
|
||||||
|
|
||||||
|
with patch("services.workflow_service.Graph.init") as mock_graph_init:
|
||||||
|
workflow_service.validate_graph_structure(mock_app, graph)
|
||||||
|
|
||||||
|
mock_graph_init.assert_called_once()
|
||||||
|
assert mock_graph_init.call_args.kwargs["graph_config"] is graph
|
||||||
|
assert "node_factory" in mock_graph_init.call_args.kwargs
|
||||||
|
|
||||||
|
def test_validate_graph_structure_propagates_graph_errors(self, workflow_service, mock_app):
|
||||||
|
graph = {"nodes": [], "edges": []}
|
||||||
|
issue = GraphValidationIssue(code="ERR", message="invalid")
|
||||||
|
|
||||||
|
with patch("services.workflow_service.Graph.init", side_effect=GraphValidationError([issue])):
|
||||||
|
with pytest.raises(GraphValidationError):
|
||||||
|
workflow_service.validate_graph_structure(mock_app, graph)
|
||||||
|
|
||||||
|
def test_validate_graph_structure_requires_nodes_and_edges(self, workflow_service, mock_app):
|
||||||
|
with pytest.raises(ValueError, match="must include 'nodes' and 'edges'"):
|
||||||
|
workflow_service.validate_graph_structure(mock_app, {"nodes": []})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user