diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8ceaa428c3..5e4fada7d9 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -10,10 +10,10 @@ from __future__ import annotations from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode +from core.workflow.runtime.graph_runtime_state import NodeProtocol @dataclass @@ -29,21 +29,26 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> ResponseSession: + def from_node(cls, node: NodeProtocol) -> ResponseSession: """ - Create a ResponseSession from an AnswerNode or EndNode. + Create a ResponseSession from a response-capable node. + + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, + but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: + - `id: str` + - `get_streaming_template() -> Template` Args: - node: Must be either an AnswerNode or EndNode instance + node: Node from the materialized workflow graph. Returns: ResponseSession configured with the node's streaming template Raises: - TypeError: If node is not an AnswerNode or EndNode + TypeError: If node is not a supported response node type. """ if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError + raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") return cls( node_id=node.id, template=node.get_streaming_template(), diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 401cecc162..acf0ee6839 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -6,12 +6,13 @@ import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Protocol +from typing import Any, ClassVar, Protocol from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.runtime.variable_pool import VariablePool @@ -103,14 +104,33 @@ class ResponseStreamCoordinatorProtocol(Protocol): ... +class NodeProtocol(Protocol): + """Structural interface for graph nodes.""" + + id: str + state: NodeState + execution_type: NodeExecutionType + node_type: ClassVar[NodeType] + + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... + + +class EdgeProtocol(Protocol): + id: str + state: NodeState + tail: str + head: str + source_handle: str + + class GraphProtocol(Protocol): """Structural interface required from graph instances attached to the runtime state.""" - nodes: Mapping[str, object] - edges: Mapping[str, object] - root_node: object + nodes: Mapping[str, NodeProtocol] + edges: Mapping[str, EdgeProtocol] + root_node: NodeProtocol - def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... + def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... @dataclass(slots=True)