Add graph-level validation to the workflow Graph.

Summary of the changes in this patch:
  * entities/__init__.py: re-export GraphRuntimeState and VariablePool.
  * graph/graph.py: promote nodes whose error strategy is FAIL_BRANCH to the
    BRANCH execution type right after node instances are created, and run the
    default validators on the constructed graph before returning it.
  * graph/validation.py (new file): GraphValidationIssue, GraphValidationError,
    the GraphValidationRule protocol, the edge-endpoint and root-node rules,
    GraphValidator and get_graph_validator().
  * nodes/node_factory.py: remove the fail-branch promotion from the factory
    (it is now performed in graph.py) together with the unused enum imports.
  * tests: unit tests covering default validation, unknown edge targets and
    the fail-branch promotion.

NOTE(review): line breaks, blank lines and indentation were reconstructed from
a whitespace-collapsed copy of this patch. Blank lines were placed so that every
hunk matches the line counts in its header. The nesting of the
`self._base_node_data = ...` assignment in `_TestNode.__init__` (placed inside
the `isinstance(data, Mapping)` guard) is an assumption - confirm it against the
upstream commit before applying.

diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py
index be70e467a0..185f0ad620 100644
--- a/api/core/workflow/entities/__init__.py
+++ b/api/core/workflow/entities/__init__.py
@@ -1,3 +1,5 @@
+from ..runtime.graph_runtime_state import GraphRuntimeState
+from ..runtime.variable_pool import VariablePool
 from .agent import AgentNodeStrategyInit
 from .graph_init_params import GraphInitParams
 from .workflow_execution import WorkflowExecution
@@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
 __all__ = [
     "AgentNodeStrategyInit",
     "GraphInitParams",
+    "GraphRuntimeState",
+    "VariablePool",
     "WorkflowExecution",
     "WorkflowNodeExecution",
 ]
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index 20b5193875..d04724425c 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -3,11 +3,12 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Protocol, cast, final
 
-from core.workflow.enums import NodeExecutionType, NodeState, NodeType
+from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
 from core.workflow.nodes.base.node import Node
 from libs.typing import is_str, is_str_dict
 
 from .edge import Edge
+from .validation import get_graph_validator
 
 logger = logging.getLogger(__name__)
 
@@ -201,6 +202,17 @@ class Graph:
 
         return GraphBuilder(graph_cls=cls)
 
+    @classmethod
+    def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
+        """
+        Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
+
+        :param nodes: mapping of node ID to node instance
+        """
+        for node in nodes.values():
+            if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
+                node.execution_type = NodeExecutionType.BRANCH
+
     @classmethod
     def _mark_inactive_root_branches(
         cls,
@@ -307,6 +319,9 @@ class Graph:
         # Create node instances
         nodes = cls._create_node_instances(node_configs_map, node_factory)
 
+        # Promote fail-branch nodes to branch execution type at graph level
+        cls._promote_fail_branch_nodes(nodes)
+
         # Get root node instance
         root_node = nodes[root_node_id]
 
@@ -314,7 +329,7 @@ class Graph:
         cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
 
         # Create and return the graph
-        return cls(
+        graph = cls(
             nodes=nodes,
             edges=edges,
             in_edges=in_edges,
@@ -322,6 +337,11 @@ class Graph:
             root_node=root_node,
         )
 
+        # Validate the graph structure using built-in validators
+        get_graph_validator().validate(graph)
+
+        return graph
+
     @property
     def node_ids(self) -> list[str]:
         """
diff --git a/api/core/workflow/graph/validation.py b/api/core/workflow/graph/validation.py
new file mode 100644
index 0000000000..87aa7db2e4
--- /dev/null
+++ b/api/core/workflow/graph/validation.py
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Protocol
+
+from core.workflow.enums import NodeExecutionType, NodeType
+
+if TYPE_CHECKING:
+    from .graph import Graph
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidationIssue:
+    """Immutable value object describing a single validation issue."""
+
+    code: str
+    message: str
+    node_id: str | None = None
+
+
+class GraphValidationError(ValueError):
+    """Raised when graph validation fails."""
+
+    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
+        if not issues:
+            raise ValueError("GraphValidationError requires at least one issue.")
+        self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
+        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
+        super().__init__(message)
+
+
+class GraphValidationRule(Protocol):
+    """Protocol that individual validation rules must satisfy."""
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        """Validate the provided graph and return any discovered issues."""
+        ...
+
+
+@dataclass(frozen=True, slots=True)
+class _EdgeEndpointValidator:
+    """Ensures all edges reference existing nodes."""
+
+    missing_node_code: str = "MISSING_NODE"
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        issues: list[GraphValidationIssue] = []
+        for edge in graph.edges.values():
+            if edge.tail not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
+                        node_id=edge.tail,
+                    )
+                )
+            if edge.head not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
+                        node_id=edge.head,
+                    )
+                )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class _RootNodeValidator:
+    """Validates root node invariants."""
+
+    invalid_root_code: str = "INVALID_ROOT"
+    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        root_node = graph.root_node
+        issues: list[GraphValidationIssue] = []
+        if root_node.id not in graph.nodes:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' is missing from the node registry.",
+                    node_id=root_node.id,
+                )
+            )
+            return issues
+
+        node_type = getattr(root_node, "node_type", None)
+        if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' must declare execution type 'root'.",
+                    node_id=root_node.id,
+                )
+            )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidator:
+    """Coordinates execution of graph validation rules."""
+
+    rules: tuple[GraphValidationRule, ...]
+
+    def validate(self, graph: Graph) -> None:
+        """Validate the graph against all configured rules."""
+        issues: list[GraphValidationIssue] = []
+        for rule in self.rules:
+            issues.extend(rule.validate(graph))
+
+        if issues:
+            raise GraphValidationError(issues)
+
+
+_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
+    _EdgeEndpointValidator(),
+    _RootNodeValidator(),
+)
+
+
+def get_graph_validator() -> GraphValidator:
+    """Construct the validator composed of default rules."""
+    return GraphValidator(_DEFAULT_RULES)
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py
index 87d1b8c435..84f63d57eb 100644
--- a/api/core/workflow/nodes/node_factory.py
+++ b/api/core/workflow/nodes/node_factory.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
 
 from typing_extensions import override
 
-from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
+from core.workflow.enums import NodeType
 from core.workflow.graph import NodeFactory
 from core.workflow.nodes.base.node import Node
 from libs.typing import is_str, is_str_dict
@@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
             raise ValueError(f"Node {node_id} missing data information")
         node_instance.init_node_data(node_data)
 
-        # If node has fail branch, change execution type to branch
-        if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
-            node_instance.execution_type = NodeExecutionType.BRANCH
-
         return node_instance
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
new file mode 100644
index 0000000000..b55d4998c4
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
@@ -0,0 +1,181 @@
+from __future__ import annotations
+
+import time
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
+from core.workflow.graph import Graph
+from core.workflow.graph.validation import GraphValidationError
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.base.node import Node
+from core.workflow.system_variable import SystemVariable
+from models.enums import UserFrom
+
+
+class _TestNode(Node):
+    node_type = NodeType.ANSWER
+    execution_type = NodeExecutionType.EXECUTABLE
+
+    @classmethod
+    def version(cls) -> str:
+        return "test"
+
+    def __init__(
+        self,
+        *,
+        id: str,
+        config: Mapping[str, object],
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        data = config.get("data", {})
+        if isinstance(data, Mapping):
+            execution_type = data.get("execution_type")
+            if isinstance(execution_type, str):
+                self.execution_type = NodeExecutionType(execution_type)
+            self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
+        self.data: dict[str, object] = {}
+
+    def init_node_data(self, data: Mapping[str, object]) -> None:
+        title = str(data.get("title", self.id))
+        desc = data.get("description")
+        error_strategy_value = data.get("error_strategy")
+        error_strategy: ErrorStrategy | None = None
+        if isinstance(error_strategy_value, ErrorStrategy):
+            error_strategy = error_strategy_value
+        elif isinstance(error_strategy_value, str):
+            error_strategy = ErrorStrategy(error_strategy_value)
+        self._base_node_data = BaseNodeData(
+            title=title,
+            desc=str(desc) if desc is not None else None,
+            error_strategy=error_strategy,
+        )
+        self.data = dict(data)
+
+    def _run(self):
+        raise NotImplementedError
+
+    def _get_error_strategy(self) -> ErrorStrategy | None:
+        return self._base_node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._base_node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._base_node_data.title
+
+    def _get_description(self) -> str | None:
+        return self._base_node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._base_node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._base_node_data
+
+
+@dataclass(slots=True)
+class _SimpleNodeFactory:
+    graph_init_params: GraphInitParams
+    graph_runtime_state: GraphRuntimeState
+
+    def create_node(self, node_config: Mapping[str, object]) -> _TestNode:
+        node_id = str(node_config["id"])
+        node = _TestNode(
+            id=node_id,
+            config=node_config,
+            graph_init_params=self.graph_init_params,
+            graph_runtime_state=self.graph_runtime_state,
+        )
+        node.init_node_data(node_config.get("data", {}))
+        return node
+
+
+@pytest.fixture
+def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
+    graph_config: dict[str, object] = {"edges": [], "nodes": []}
+    init_params = GraphInitParams(
+        tenant_id="tenant",
+        app_id="app",
+        workflow_id="workflow",
+        graph_config=graph_config,
+        user_id="user",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={})
+    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+    factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state)
+    return factory, graph_config
+
+
+def test_graph_initialization_runs_default_validators(
+    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
+):
+    node_factory, graph_config = graph_init_dependencies
+    graph_config["nodes"] = [
+        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
+        {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
+    ]
+    graph_config["edges"] = [
+        {"source": "start", "target": "answer", "sourceHandle": "success"},
+    ]
+
+    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+
+    assert graph.root_node.id == "start"
+    assert "answer" in graph.nodes
+
+
+def test_graph_validation_fails_for_unknown_edge_targets(
+    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
+) -> None:
+    node_factory, graph_config = graph_init_dependencies
+    graph_config["nodes"] = [
+        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
+    ]
+    graph_config["edges"] = [
+        {"source": "start", "target": "missing", "sourceHandle": "success"},
+    ]
+
+    with pytest.raises(GraphValidationError) as exc:
+        Graph.init(graph_config=graph_config, node_factory=node_factory)
+
+    assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues)
+
+
+def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
+    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
+) -> None:
+    node_factory, graph_config = graph_init_dependencies
+    graph_config["nodes"] = [
+        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
+        {
+            "id": "branch",
+            "data": {
+                "type": NodeType.IF_ELSE,
+                "title": "Branch",
+                "error_strategy": ErrorStrategy.FAIL_BRANCH,
+            },
+        },
+    ]
+    graph_config["edges"] = [
+        {"source": "start", "target": "branch", "sourceHandle": "success"},
+    ]
+
+    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+
+    assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
+