diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 67f16743c3..3c7dcb8d66 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -33,7 +33,13 @@ from core.workflow.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node @@ -93,7 +99,7 @@ class AgentNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: + def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError try: @@ -482,7 +488,7 @@ class AgentNode(Node): node_type: NodeType, node_id: str, node_execution_id: str, - ) -> Generator: + ) -> Generator[NodeEventBase, None, None]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8816e22a85..e5db872e3b 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,8 @@ import logging from abc import abstractmethod -from collections.abc import Callable, Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from collections.abc import Generator, Mapping, Sequence +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any, ClassVar from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom @@ -88,14 +89,14 @@ class Node: def init_node_data(self, data: Mapping[str, Any]) -> None: ... @abstractmethod - def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]": + def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ Run node :return: """ raise NotImplementedError - def run(self) -> "Generator[GraphNodeEventBase, None, None]": + def run(self) -> Generator[GraphNodeEventBase, None, None]: # Generate a single node execution ID to use for all events if not self._node_execution_id: self._node_execution_id = str(uuid4()) @@ -142,8 +143,9 @@ class Node: # Handle event stream for event in result: - if isinstance(event, NodeEventBase): - event = self._convert_node_event_to_graph_node_event(event) + # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase + if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] + event = self._dispatch(event) if not event.in_iteration_id and not event.in_loop_id: event.id = self._node_execution_id @@ -240,7 +242,7 @@ class Node: return False @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return {} @classmethod @@ -261,7 +263,7 @@ class Node: # to BaseNodeData properties in a type-safe way @abstractmethod - def _get_error_strategy(self) -> Optional["ErrorStrategy"]: + def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" ... @@ -276,7 +278,7 @@ class Node: ... @abstractmethod - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: """Get the node description.""" ... @@ -292,7 +294,7 @@ class Node: # Public interface properties that delegate to abstract methods @property - def error_strategy(self) -> Optional["ErrorStrategy"]: + def error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" return self._get_error_strategy() @@ -307,7 +309,7 @@ class Node: return self._get_title() @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """Get the node description.""" return self._get_description() @@ -335,29 +337,15 @@ class Node: start_at=self._start_at, node_run_result=result, ) - raise Exception(f"result status {result.status} not supported") + case _: + raise Exception(f"result status {result.status} not supported") - def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase: - handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = { - StreamChunkEvent: self._handle_stream_chunk_event, - StreamCompletedEvent: self._handle_stream_completed_event, - AgentLogEvent: self._handle_agent_log_event, - LoopStartedEvent: self._handle_loop_started_event, - LoopNextEvent: self._handle_loop_next_event, - LoopSucceededEvent: self._handle_loop_succeeded_event, - LoopFailedEvent: self._handle_loop_failed_event, - IterationStartedEvent: self._handle_iteration_started_event, - IterationNextEvent: self._handle_iteration_next_event, - IterationSucceededEvent: self._handle_iteration_succeeded_event, - IterationFailedEvent: self._handle_iteration_failed_event, - RunRetrieverResourceEvent: self._handle_run_retriever_resource_event, - } - handler = handler_maps.get(type(event)) - if not handler: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - return handler(event) + @singledispatchmethod + def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: + raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: + @_dispatch.register + def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: return NodeRunStreamChunkEvent( id=self._node_execution_id, node_id=self._node_id, @@ -367,7 +355,8 @@ class Node: is_final=event.is_final, ) - def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: + @_dispatch.register + def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -386,9 +375,13 @@ class Node: node_run_result=event.node_run_result, error=event.node_run_result.error, ) - raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}") + case _: + raise NotImplementedError( + f"Node {self._node_id} does not support status {event.node_run_result.status}" + ) - def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: + @_dispatch.register + def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( id=self._node_execution_id, node_id=self._node_id, @@ -403,7 +396,8 @@ class Node: metadata=event.metadata, ) - def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: + @_dispatch.register + def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -415,7 +409,8 @@ class Node: predecessor_node_id=event.predecessor_node_id, ) - def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: + @_dispatch.register + def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: return NodeRunLoopNextEvent( id=self._node_execution_id, node_id=self._node_id, @@ -425,7 +420,8 @@ class Node: pre_loop_output=event.pre_loop_output, ) - def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: + @_dispatch.register + def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: return NodeRunLoopSucceededEvent( id=self._node_execution_id, node_id=self._node_id, @@ -438,7 +434,8 @@ class Node: steps=event.steps, ) - def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: + @_dispatch.register + def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: return NodeRunLoopFailedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -452,7 +449,8 @@ class Node: error=event.error, ) - def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: + @_dispatch.register + def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: return NodeRunIterationStartedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -464,7 +462,8 @@ class Node: predecessor_node_id=event.predecessor_node_id, ) - def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: + @_dispatch.register + def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: return NodeRunIterationNextEvent( id=self._node_execution_id, node_id=self._node_id, @@ -474,7 +473,8 @@ class Node: pre_iteration_output=event.pre_iteration_output, ) - def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: + @_dispatch.register + def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: return NodeRunIterationSucceededEvent( id=self._node_execution_id, node_id=self._node_id, @@ -487,7 +487,8 @@ class Node: steps=event.steps, ) - def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: + @_dispatch.register + def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: return NodeRunIterationFailedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -501,7 +502,8 @@ class Node: error=event.error, ) - def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: + @_dispatch.register + def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: return NodeRunRetrieverResourceEvent( id=self._node_execution_id, node_id=self._node_id, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c2c9def30c..6829d649d3 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -19,7 +19,7 @@ from core.workflow.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser @@ -55,7 +55,7 @@ class ToolNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node """ diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 5d6362b1c4..cbf4cfd136 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -356,8 +356,8 @@ class WorkflowCycleManager: workflow_execution: WorkflowExecution, event: QueueNodeStartedEvent, status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, - created_at: Optional[datetime] = None, + error: str | None = None, + created_at: datetime | None = None, ) -> WorkflowNodeExecution: """Create a node execution from an event.""" now = naive_utc_now() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 593d577f0e..05cd9610ef 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -557,7 +557,9 @@ class WorkflowService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + def get_default_block_config( + self, node_type: str, filters: Mapping[str, object] | None = None + ) -> Mapping[str, object]: """ Get default config of node. :param node_type: node type @@ -568,12 +570,12 @@ class WorkflowService: # return default block config if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: - return None + return {} node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] default_config = node_class.get_default_config(filters=filters) if not default_config: - return None + return {} return default_config