mirror of https://github.com/langgenius/dify.git
feat: enhance typing
This commit is contained in:
parent
bfbb36756a
commit
3dee8064ba
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .edge import Edge
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class NodeFactory(Protocol):
|
|||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ class NodeFactory(Protocol):
|
|||
...
|
||||
|
||||
|
||||
@final
|
||||
class Graph:
|
||||
"""Graph representation with nodes and edges for workflow execution."""
|
||||
|
||||
|
|
@ -58,18 +60,18 @@ class Graph:
|
|||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, Any]] = {}
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
|
|
@ -79,8 +81,8 @@ class Graph:
|
|||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
edge_configs: list[dict[str, Any]],
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -97,10 +99,10 @@ class Graph:
|
|||
return root_node_id
|
||||
|
||||
# Find nodes with no incoming edges
|
||||
nodes_with_incoming = set()
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge_config in edge_configs:
|
||||
target = edge_config.get("target")
|
||||
if target:
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
|
||||
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
|
||||
|
|
@ -108,8 +110,13 @@ class Graph:
|
|||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data", {})
|
||||
if node_data.get("type") == NodeType.START.value:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
|
|
@ -122,7 +129,7 @@ class Graph:
|
|||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, Any]]
|
||||
cls, edge_configs: list[dict[str, object]]
|
||||
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""
|
||||
Build edge objects and mappings from edge configurations.
|
||||
|
|
@ -139,7 +146,7 @@ class Graph:
|
|||
source = edge_config.get("source")
|
||||
target = edge_config.get("target")
|
||||
|
||||
if not source or not target:
|
||||
if not is_str(source) or not is_str(target):
|
||||
continue
|
||||
|
||||
# Create edge
|
||||
|
|
@ -147,6 +154,8 @@ class Graph:
|
|||
edge_counter += 1
|
||||
|
||||
source_handle = edge_config.get("sourceHandle", "source")
|
||||
if not is_str(source_handle):
|
||||
continue
|
||||
|
||||
edge = Edge(
|
||||
id=edge_id,
|
||||
|
|
@ -164,7 +173,7 @@ class Graph:
|
|||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_factory: "NodeFactory",
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
|
|
@ -256,7 +265,7 @@ class Graph:
|
|||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
|
|
@ -272,10 +281,12 @@ class Graph:
|
|||
edge_configs = graph_config.get("edges", [])
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
|
@ -10,6 +13,7 @@ if TYPE_CHECKING:
|
|||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
|
@ -26,10 +30,8 @@ class DifyNodeFactory(NodeFactory):
|
|||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
def create_node(
|
||||
self,
|
||||
node_config: dict[str, Any],
|
||||
) -> Node:
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
|
|
@ -39,11 +41,14 @@ class DifyNodeFactory(NodeFactory):
|
|||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not node_type_str:
|
||||
raise ValueError(f"Node {node_id} missing type information")
|
||||
|
|
@ -72,6 +77,8 @@ class DifyNodeFactory(NodeFactory):
|
|||
|
||||
# Initialize node with provided data
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
node_instance.init_node_data(node_data)
|
||||
|
||||
# If node has fail branch, change execution type to branch
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
from typing import TypeGuard
|
||||
|
||||
|
||||
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
|
||||
return isinstance(v, dict)
|
||||
|
||||
|
||||
def is_str(v: object) -> TypeGuard[str]:
|
||||
return isinstance(v, str)
|
||||
Loading…
Reference in New Issue