feat: enhance typing

This commit is contained in:
-LAN- 2025-08-29 13:17:02 +08:00
parent bfbb36756a
commit 3dee8064ba
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
3 changed files with 50 additions and 23 deletions

View File

@ -1,10 +1,11 @@
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Protocol, cast
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
@ -19,7 +20,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, Any]) -> Node:
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data.
@ -30,6 +31,7 @@ class NodeFactory(Protocol):
...
@final
class Graph:
"""Graph representation with nodes and edges for workflow execution."""
@ -58,18 +60,18 @@ class Graph:
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, Any]] = {}
node_configs_map: dict[str, dict[str, object]] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
@ -79,8 +81,8 @@ class Graph:
@classmethod
def _find_root_node_id(
cls,
node_configs_map: dict[str, dict[str, Any]],
edge_configs: list[dict[str, Any]],
node_configs_map: Mapping[str, Mapping[str, object]],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
@ -97,10 +99,10 @@ class Graph:
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming = set()
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if target:
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
@ -108,8 +110,13 @@ class Graph:
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data", {})
if node_data.get("type") == NodeType.START.value:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type == NodeType.START:
start_node_id = nid
break
@ -122,7 +129,7 @@ class Graph:
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, Any]]
cls, edge_configs: list[dict[str, object]]
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
"""
Build edge objects and mappings from edge configurations.
@ -139,7 +146,7 @@ class Graph:
source = edge_config.get("source")
target = edge_config.get("target")
if not source or not target:
if not is_str(source) or not is_str(target):
continue
# Create edge
@ -147,6 +154,8 @@ class Graph:
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not is_str(source_handle):
continue
edge = Edge(
id=edge_id,
@ -164,7 +173,7 @@ class Graph:
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, Any]],
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
) -> dict[str, Node]:
"""
@ -256,7 +265,7 @@ class Graph:
def init(
cls,
*,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
root_node_id: str | None = None,
) -> "Graph":
@ -272,10 +281,12 @@ class Graph:
edge_configs = graph_config.get("edges", [])
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations

View File

@ -1,8 +1,11 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -10,6 +13,7 @@ if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
@final
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that uses the traditional node mapping.
@ -26,10 +30,8 @@ class DifyNodeFactory(NodeFactory):
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
def create_node(
self,
node_config: dict[str, Any],
) -> Node:
@override
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
@ -39,11 +41,14 @@ class DifyNodeFactory(NodeFactory):
"""
# Get node_id from config
node_id = node_config.get("id")
if not node_id:
if not is_str(node_id):
raise ValueError("Node config missing id")
# Get node type from config
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_type_str = node_data.get("type")
if not node_type_str:
raise ValueError(f"Node {node_id} missing type information")
@ -72,6 +77,8 @@ class DifyNodeFactory(NodeFactory):
# Initialize node with provided data
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch

9
api/libs/typing.py Normal file
View File

@ -0,0 +1,9 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)