refactor(graph_engine): use singledispatch in Node

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-09-10 20:59:34 +08:00
parent f56fccee9d
commit 00a1af8506
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
5 changed files with 64 additions and 54 deletions

View File

@ -33,7 +33,13 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
@ -93,7 +99,7 @@ class AgentNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
try:
@ -482,7 +488,7 @@ class AgentNode(Node):
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""

View File

@ -1,7 +1,8 @@
import logging
from abc import abstractmethod
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
@ -88,14 +89,14 @@ class Node:
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
@ -142,8 +143,9 @@ class Node:
# Handle event stream
for event in result:
if isinstance(event, NodeEventBase):
event = self._convert_node_event_to_graph_node_event(event)
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
event = self._dispatch(event)
if not event.in_iteration_id and not event.in_loop_id:
event.id = self._node_execution_id
@ -240,7 +242,7 @@ class Node:
return False
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {}
@classmethod
@ -261,7 +263,7 @@ class Node:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
@ -276,7 +278,7 @@ class Node:
...
@abstractmethod
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
"""Get the node description."""
...
@ -292,7 +294,7 @@ class Node:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional["ErrorStrategy"]:
def error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -307,7 +309,7 @@ class Node:
return self._get_title()
@property
def description(self) -> Optional[str]:
def description(self) -> str | None:
"""Get the node description."""
return self._get_description()
@ -335,29 +337,15 @@ class Node:
start_at=self._start_at,
node_run_result=result,
)
raise Exception(f"result status {result.status} not supported")
case _:
raise Exception(f"result status {result.status} not supported")
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
StreamChunkEvent: self._handle_stream_chunk_event,
StreamCompletedEvent: self._handle_stream_completed_event,
AgentLogEvent: self._handle_agent_log_event,
LoopStartedEvent: self._handle_loop_started_event,
LoopNextEvent: self._handle_loop_next_event,
LoopSucceededEvent: self._handle_loop_succeeded_event,
LoopFailedEvent: self._handle_loop_failed_event,
IterationStartedEvent: self._handle_iteration_started_event,
IterationNextEvent: self._handle_iteration_next_event,
IterationSucceededEvent: self._handle_iteration_succeeded_event,
IterationFailedEvent: self._handle_iteration_failed_event,
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
}
handler = handler_maps.get(type(event))
if not handler:
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
return handler(event)
@singledispatchmethod
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -367,7 +355,8 @@ class Node:
is_final=event.is_final,
)
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
@_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
@ -386,9 +375,13 @@ class Node:
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
case _:
raise NotImplementedError(
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -403,7 +396,8 @@ class Node:
metadata=event.metadata,
)
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -415,7 +409,8 @@ class Node:
predecessor_node_id=event.predecessor_node_id,
)
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -425,7 +420,8 @@ class Node:
pre_loop_output=event.pre_loop_output,
)
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -438,7 +434,8 @@ class Node:
steps=event.steps,
)
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -452,7 +449,8 @@ class Node:
error=event.error,
)
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -464,7 +462,8 @@ class Node:
predecessor_node_id=event.predecessor_node_id,
)
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -474,7 +473,8 @@ class Node:
pre_iteration_output=event.pre_iteration_output,
)
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -487,7 +487,8 @@ class Node:
steps=event.steps,
)
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -501,7 +502,8 @@ class Node:
error=event.error,
)
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
node_id=self._node_id,

View File

@ -19,7 +19,7 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@ -55,7 +55,7 @@ class ToolNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node
"""

View File

@ -356,8 +356,8 @@ class WorkflowCycleManager:
workflow_execution: WorkflowExecution,
event: QueueNodeStartedEvent,
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
created_at: Optional[datetime] = None,
error: str | None = None,
created_at: datetime | None = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = naive_utc_now()

View File

@ -557,7 +557,9 @@ class WorkflowService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
def get_default_block_config(
self, node_type: str, filters: Mapping[str, object] | None = None
) -> Mapping[str, object]:
"""
Get default config of node.
:param node_type: node type
@ -568,12 +570,12 @@ class WorkflowService:
# return default block config
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
return {}
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return {}
return default_config