From 3dee8064badd64f195902feb7c749384a947c7d7 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 29 Aug 2025 13:17:02 +0800 Subject: [PATCH] feat: enhance typing --- api/core/workflow/graph/graph.py | 45 +++++++++++++++---------- api/core/workflow/nodes/node_factory.py | 19 +++++++---- api/libs/typing.py | 9 +++++ 3 files changed, 50 insertions(+), 23 deletions(-) create mode 100644 api/libs/typing.py diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 3c21b4659f..dc38d4d2a3 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -1,10 +1,11 @@ import logging from collections import defaultdict -from collections.abc import Mapping -from typing import Any, Protocol, cast +from collections.abc import Mapping, Sequence +from typing import Protocol, cast, final from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node +from libs.typing import is_str, is_str_dict from .edge import Edge @@ -19,7 +20,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ - def create_node(self, node_config: dict[str, Any]) -> Node: + def create_node(self, node_config: dict[str, object]) -> Node: """ Create a Node instance from node configuration data. @@ -30,6 +31,7 @@ class NodeFactory(Protocol): ... +@final class Graph: """Graph representation with nodes and edges for workflow execution.""" @@ -58,18 +60,18 @@ class Graph: self.root_node = root_node @classmethod - def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]: """ Parse node configurations and build a mapping of node IDs to configs. :param node_configs: list of node configuration dictionaries :return: mapping of node ID to node config """ - node_configs_map: dict[str, dict[str, Any]] = {} + node_configs_map: dict[str, dict[str, object]] = {} for node_config in node_configs: node_id = node_config.get("id") - if not node_id: + if not node_id or not isinstance(node_id, str): continue node_configs_map[node_id] = node_config @@ -79,8 +81,8 @@ class Graph: @classmethod def _find_root_node_id( cls, - node_configs_map: dict[str, dict[str, Any]], - edge_configs: list[dict[str, Any]], + node_configs_map: Mapping[str, Mapping[str, object]], + edge_configs: Sequence[Mapping[str, object]], root_node_id: str | None = None, ) -> str: """ @@ -97,10 +99,10 @@ class Graph: return root_node_id # Find nodes with no incoming edges - nodes_with_incoming = set() + nodes_with_incoming: set[str] = set() for edge_config in edge_configs: target = edge_config.get("target") - if target: + if isinstance(target, str): nodes_with_incoming.add(target) root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming] @@ -108,8 +110,13 @@ class Graph: # Prefer START node if available start_node_id = None for nid in root_candidates: - node_data = node_configs_map[nid].get("data", {}) - if node_data.get("type") == NodeType.START.value: + node_data = node_configs_map[nid].get("data") + if not is_str_dict(node_data): + continue + node_type = node_data.get("type") + if not isinstance(node_type, str): + continue + if node_type == NodeType.START: start_node_id = nid break @@ -122,7 +129,7 @@ class Graph: @classmethod def _build_edges( - cls, edge_configs: list[dict[str, Any]] + cls, edge_configs: list[dict[str, object]] ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: """ Build edge objects and mappings from edge configurations. @@ -139,7 +146,7 @@ class Graph: source = edge_config.get("source") target = edge_config.get("target") - if not source or not target: + if not is_str(source) or not is_str(target): continue # Create edge @@ -147,6 +154,8 @@ class Graph: edge_counter += 1 source_handle = edge_config.get("sourceHandle", "source") + if not is_str(source_handle): + continue edge = Edge( id=edge_id, @@ -164,7 +173,7 @@ class Graph: @classmethod def _create_node_instances( cls, - node_configs_map: dict[str, dict[str, Any]], + node_configs_map: dict[str, dict[str, object]], node_factory: "NodeFactory", ) -> dict[str, Node]: """ @@ -256,7 +265,7 @@ class Graph: def init( cls, *, - graph_config: Mapping[str, Any], + graph_config: Mapping[str, object], node_factory: "NodeFactory", root_node_id: str | None = None, ) -> "Graph": @@ -272,10 +281,12 @@ class Graph: edge_configs = graph_config.get("edges", []) node_configs = graph_config.get("nodes", []) + edge_configs = cast(list[dict[str, object]], edge_configs) + node_configs = cast(list[dict[str, object]], node_configs) + if not node_configs: raise ValueError("Graph must have at least one node") - edge_configs = cast(list, edge_configs) node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"] # Parse node configurations diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index bf6a1389fc..5ded1ad44c 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -1,8 +1,11 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, final + +from typing_extensions import override from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node +from libs.typing import is_str, is_str_dict from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -10,6 +13,7 @@ if TYPE_CHECKING: from core.workflow.entities import GraphInitParams, GraphRuntimeState +@final class DifyNodeFactory(NodeFactory): """ Default implementation of NodeFactory that uses the traditional node mapping. @@ -26,10 +30,8 @@ class DifyNodeFactory(NodeFactory): self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state - def create_node( - self, - node_config: dict[str, Any], - ) -> Node: + @override + def create_node(self, node_config: dict[str, object]) -> Node: """ Create a Node instance from node configuration data using the traditional mapping. @@ -39,11 +41,14 @@ class DifyNodeFactory(NodeFactory): """ # Get node_id from config node_id = node_config.get("id") - if not node_id: + if not is_str(node_id): raise ValueError("Node config missing id") # Get node type from config node_data = node_config.get("data", {}) + if not is_str_dict(node_data): + raise ValueError(f"Node {node_id} missing data information") + node_type_str = node_data.get("type") if not node_type_str: raise ValueError(f"Node {node_id} missing type information") @@ -72,6 +77,8 @@ class DifyNodeFactory(NodeFactory): # Initialize node with provided data node_data = node_config.get("data", {}) + if not is_str_dict(node_data): + raise ValueError(f"Node {node_id} missing data information") node_instance.init_node_data(node_data) # If node has fail branch, change execution type to branch diff --git a/api/libs/typing.py b/api/libs/typing.py new file mode 100644 index 0000000000..f84e9911e0 --- /dev/null +++ b/api/libs/typing.py @@ -0,0 +1,9 @@ +from typing import TypeGuard + + +def is_str_dict(v: object) -> TypeGuard[dict[str, object]]: + return isinstance(v, dict) + + +def is_str(v: object) -> TypeGuard[str]: + return isinstance(v, str)