mirror of https://github.com/langgenius/dify.git
refactor(graph_engine): use singledispatch in Node
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f56fccee9d
commit
00a1af8506
|
|
@ -33,7 +33,13 @@ from core.workflow.enums import (
|
|||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
|
@ -93,7 +99,7 @@ class AgentNode(Node):
|
|||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
try:
|
||||
|
|
@ -482,7 +488,7 @@ class AgentNode(Node):
|
|||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator:
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -88,14 +89,14 @@ class Node:
|
|||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
|
|
@ -142,8 +143,9 @@ class Node:
|
|||
|
||||
# Handle event stream
|
||||
for event in result:
|
||||
if isinstance(event, NodeEventBase):
|
||||
event = self._convert_node_event_to_graph_node_event(event)
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event = self._dispatch(event)
|
||||
|
||||
if not event.in_iteration_id and not event.in_loop_id:
|
||||
event.id = self._node_execution_id
|
||||
|
|
@ -240,7 +242,7 @@ class Node:
|
|||
return False
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -261,7 +263,7 @@ class Node:
|
|||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
|
|
@ -276,7 +278,7 @@ class Node:
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
|
|
@ -292,7 +294,7 @@ class Node:
|
|||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional["ErrorStrategy"]:
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
|
|
@ -307,7 +309,7 @@ class Node:
|
|||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[str]:
|
||||
def description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
|
|
@ -335,29 +337,15 @@ class Node:
|
|||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
case _:
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
|
||||
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
|
||||
StreamChunkEvent: self._handle_stream_chunk_event,
|
||||
StreamCompletedEvent: self._handle_stream_completed_event,
|
||||
AgentLogEvent: self._handle_agent_log_event,
|
||||
LoopStartedEvent: self._handle_loop_started_event,
|
||||
LoopNextEvent: self._handle_loop_next_event,
|
||||
LoopSucceededEvent: self._handle_loop_succeeded_event,
|
||||
LoopFailedEvent: self._handle_loop_failed_event,
|
||||
IterationStartedEvent: self._handle_iteration_started_event,
|
||||
IterationNextEvent: self._handle_iteration_next_event,
|
||||
IterationSucceededEvent: self._handle_iteration_succeeded_event,
|
||||
IterationFailedEvent: self._handle_iteration_failed_event,
|
||||
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
|
||||
}
|
||||
handler = handler_maps.get(type(event))
|
||||
if not handler:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
return handler(event)
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
|
||||
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -367,7 +355,8 @@ class Node:
|
|||
is_final=event.is_final,
|
||||
)
|
||||
|
||||
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
|
|
@ -386,9 +375,13 @@ class Node:
|
|||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -403,7 +396,8 @@ class Node:
|
|||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -415,7 +409,8 @@ class Node:
|
|||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -425,7 +420,8 @@ class Node:
|
|||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
||||
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -438,7 +434,8 @@ class Node:
|
|||
steps=event.steps,
|
||||
)
|
||||
|
||||
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -452,7 +449,8 @@ class Node:
|
|||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -464,7 +462,8 @@ class Node:
|
|||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -474,7 +473,8 @@ class Node:
|
|||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
||||
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -487,7 +487,8 @@ class Node:
|
|||
steps=event.steps,
|
||||
)
|
||||
|
||||
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
@ -501,7 +502,8 @@ class Node:
|
|||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from core.workflow.enums import (
|
|||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
|
@ -55,7 +55,7 @@ class ToolNode(Node):
|
|||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -356,8 +356,8 @@ class WorkflowCycleManager:
|
|||
workflow_execution: WorkflowExecution,
|
||||
event: QueueNodeStartedEvent,
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
error: Optional[str] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
error: str | None = None,
|
||||
created_at: datetime | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a node execution from an event."""
|
||||
now = naive_utc_now()
|
||||
|
|
|
|||
|
|
@ -557,7 +557,9 @@ class WorkflowService:
|
|||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
def get_default_block_config(
|
||||
self, node_type: str, filters: Mapping[str, object] | None = None
|
||||
) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
|
|
@ -568,12 +570,12 @@ class WorkflowService:
|
|||
|
||||
# return default block config
|
||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||
return None
|
||||
return {}
|
||||
|
||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
if not default_config:
|
||||
return None
|
||||
return {}
|
||||
|
||||
return default_config
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue