mirror of https://github.com/langgenius/dify.git

fix(graph_engine): error strategy fall. (#26078)

Signed-off-by: -LAN- <laipz8200@outlook.com>

This commit is contained in:
parent f4522fd695
commit 2e2c87c5a1
@@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
     completed: bool = Field(default=False)
     aborted: bool = Field(default=False)
     error: GraphExecutionErrorState | None = Field(default=None)
-    node_executions: list[NodeExecutionState] = Field(default_factory=list)
+    exceptions_count: int = Field(default=0)
+    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])


 def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:

@@ -103,7 +104,8 @@ class GraphExecution:
     completed: bool = False
     aborted: bool = False
     error: Exception | None = None
-    node_executions: dict[str, NodeExecution] = field(default_factory=dict)
+    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
+    exceptions_count: int = 0

     def start(self) -> None:
         """Mark the graph execution as started."""

@@ -172,6 +174,7 @@ class GraphExecution:
             completed=self.completed,
             aborted=self.aborted,
             error=_serialize_error(self.error),
+            exceptions_count=self.exceptions_count,
             node_executions=node_states,
         )

@@ -195,6 +198,7 @@ class GraphExecution:
         self.completed = state.completed
         self.aborted = state.aborted
         self.error = _deserialize_error(state.error)
+        self.exceptions_count = state.exceptions_count
         self.node_executions = {
             item.node_id: NodeExecution(
                 node_id=item.node_id,

@@ -205,3 +209,7 @@ class GraphExecution:
             )
             for item in state.node_executions
         }
+
+    def record_node_failure(self) -> None:
+        """Increment the count of node failures encountered during execution."""
+        self.exceptions_count += 1
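The hunks above thread a failure counter through the GraphExecution domain model: record_node_failure() increments exceptions_count, and the counter is included in the serialized state so it survives a dump/load round trip. A minimal standalone sketch of the same pattern (simplified stand-ins, not the actual dify classes; pydantic v2 assumed):

    from dataclasses import dataclass

    from pydantic import BaseModel, Field


    class ExecutionState(BaseModel):
        completed: bool = Field(default=False)
        exceptions_count: int = Field(default=0)


    @dataclass
    class Execution:
        completed: bool = False
        exceptions_count: int = 0

        def record_node_failure(self) -> None:
            # Called once per node failure that the engine recovers from.
            self.exceptions_count += 1

        def dump(self) -> str:
            return ExecutionState(completed=self.completed, exceptions_count=self.exceptions_count).model_dump_json()

        def load(self, raw: str) -> None:
            state = ExecutionState.model_validate_json(raw)
            self.completed = state.completed
            self.exceptions_count = state.exceptions_count


    execution = Execution()
    execution.record_node_failure()

    restored = Execution()
    restored.load(execution.dump())
    assert restored.exceptions_count == 1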
@@ -3,11 +3,12 @@ Event handler implementations for different event types.
 """

 import logging
+from collections.abc import Mapping
 from functools import singledispatchmethod
 from typing import TYPE_CHECKING, final

 from core.workflow.entities import GraphRuntimeState
-from core.workflow.enums import NodeExecutionType
+from core.workflow.enums import ErrorStrategy, NodeExecutionType
 from core.workflow.graph import Graph
 from core.workflow.graph_events import (
     GraphNodeEventBase,
@@ -122,13 +123,15 @@ class EventHandler:
         """
         # Track execution in domain model
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
+        is_initial_attempt = node_execution.retry_count == 0
         node_execution.mark_started(event.id)

         # Track in response coordinator for stream ordering
         self._response_coordinator.track_node_execution(event.node_id, event.id)

-        # Collect the event
-        self._event_collector.collect(event)
+        # Collect the event only for the first attempt; retries remain silent
+        if is_initial_attempt:
+            self._event_collector.collect(event)

     @_dispatch.register
     def _(self, event: NodeRunStreamChunkEvent) -> None:
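With this change, only a node's first attempt produces a collected NodeRunStartedEvent; subsequent retries still call mark_started but stay silent, so observers do not see duplicate start events. A toy illustration of the guard (hypothetical names, not the real EventHandler):

    class NodeExecution:
        """Tracks how many times a node has been retried."""

        def __init__(self) -> None:
            self.retry_count = 0


    def handle_node_started(execution: NodeExecution, collected: list[str]) -> None:
        # Only the very first attempt is surfaced to observers.
        if execution.retry_count == 0:
            collected.append("NodeRunStartedEvent")


    events: list[str] = []
    node = NodeExecution()
    handle_node_started(node, events)  # first attempt -> collected
    node.retry_count += 1
    handle_node_started(node, events)  # retry -> suppressed
    assert events == ["NodeRunStartedEvent"]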
@@ -161,7 +164,7 @@ class EventHandler:
         node_execution.mark_taken()

         # Store outputs in variable pool
-        self._store_node_outputs(event)
+        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

         # Forward to response coordinator and emit streaming events
         streaming_events = self._response_coordinator.intercept_event(event)
@@ -191,7 +194,7 @@ class EventHandler:

         # Handle response node outputs
         if node.execution_type == NodeExecutionType.RESPONSE:
-            self._update_response_outputs(event)
+            self._update_response_outputs(event.node_run_result.outputs)

         # Collect the event
         self._event_collector.collect(event)
@@ -207,6 +210,7 @@ class EventHandler:
         # Update domain model
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         node_execution.mark_failed(event.error)
+        self._graph_execution.record_node_failure()

         result = self._error_handler.handle_node_failure(event)

@@ -227,10 +231,40 @@ class EventHandler:
         Args:
             event: The node exception event
         """
-        # Node continues via fail-branch, so it's technically "succeeded"
+        # Node continues via fail-branch/default-value, treat as completion
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         node_execution.mark_taken()

+        # Persist outputs produced by the exception strategy (e.g. default values)
+        self._store_node_outputs(event.node_id, event.node_run_result.outputs)
+
+        node = self._graph.nodes[event.node_id]
+
+        if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
+            ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
+        elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
+            ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
+                event.node_id, event.node_run_result.edge_source_handle
+            )
+        else:
+            raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
+
+        for edge_event in edge_streaming_events:
+            self._event_collector.collect(edge_event)
+
+        for node_id in ready_nodes:
+            self._state_manager.enqueue_node(node_id)
+            self._state_manager.start_execution(node_id)
+
+        # Update response outputs if applicable
+        if node.execution_type == NodeExecutionType.RESPONSE:
+            self._update_response_outputs(event.node_run_result.outputs)
+
+        self._state_manager.finish_execution(event.node_id)
+
+        # Collect the exception event for observers
+        self._event_collector.collect(event)
+
     @_dispatch.register
     def _(self, event: NodeRunRetryEvent) -> None:
         """
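The exception handler now routes by the node's error strategy: DEFAULT_VALUE behaves like a normal success (all outgoing edges are processed), FAIL_BRANCH fires only the fail-branch edge, and anything else raises. A simplified sketch of that dispatch (enum values assumed to mirror core.workflow.enums.ErrorStrategy; the edge-processor calls are reduced to strings for illustration):

    from enum import Enum


    class ErrorStrategy(Enum):
        DEFAULT_VALUE = "default-value"
        FAIL_BRANCH = "fail-branch"


    def route_after_exception(strategy: ErrorStrategy, node_id: str, edge_source_handle: str) -> str:
        if strategy is ErrorStrategy.DEFAULT_VALUE:
            # Recovered with default outputs: downstream edges fire as on success.
            return f"process_node_success({node_id})"
        if strategy is ErrorStrategy.FAIL_BRANCH:
            # Only the fail-branch edge fires.
            return f"handle_branch_completion({node_id}, {edge_source_handle})"
        raise NotImplementedError(f"Unsupported error strategy: {strategy}")


    assert route_after_exception(ErrorStrategy.DEFAULT_VALUE, "llm", "source") == "process_node_success(llm)"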
@@ -242,21 +276,31 @@ class EventHandler:
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         node_execution.increment_retry()

-    def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
+        # Finish the previous attempt before re-queuing the node
+        self._state_manager.finish_execution(event.node_id)
+
+        # Emit retry event for observers
+        self._event_collector.collect(event)
+
+        # Re-queue node for execution
+        self._state_manager.enqueue_node(event.node_id)
+        self._state_manager.start_execution(event.node_id)
+
+    def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
         """
         Store node outputs in the variable pool.

         Args:
             event: The node succeeded event containing outputs
         """
-        for variable_name, variable_value in event.node_run_result.outputs.items():
-            self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
+        for variable_name, variable_value in outputs.items():
+            self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

-    def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
+    def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
         """Update response outputs for response nodes."""
         # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
         # in runtime state, rather than allowing nodes to directly access runtime state.
-        for key, value in event.node_run_result.outputs.items():
+        for key, value in outputs.items():
             if key == "answer":
                 existing = self._graph_runtime_state.get_output("answer", "")
                 if existing:
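The retry handler now closes out the failed attempt itself and puts the node back on the ready queue, in addition to emitting the retry event. A standalone sketch of that sequence (in-memory stand-ins, not the real GraphStateManager API):

    from collections import deque


    class StateManager:
        def __init__(self) -> None:
            self.ready: deque[str] = deque()
            self.running: set[str] = set()

        def finish_execution(self, node_id: str) -> None:
            self.running.discard(node_id)

        def enqueue_node(self, node_id: str) -> None:
            self.ready.append(node_id)

        def start_execution(self, node_id: str) -> None:
            self.running.add(node_id)


    def handle_retry(node_id: str, state: StateManager, collected: list[str]) -> None:
        state.finish_execution(node_id)        # close the failed attempt
        collected.append("NodeRunRetryEvent")  # observers still see the retry
        state.enqueue_node(node_id)            # re-queue for another attempt
        state.start_execution(node_id)


    events: list[str] = []
    manager = StateManager()
    manager.start_execution("node-1")
    handle_retry("node-1", manager, events)
    assert events == ["NodeRunRetryEvent"]
    assert list(manager.ready) == ["node-1"]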
@@ -23,6 +23,7 @@ from core.workflow.graph_events import (
     GraphNodeEventBase,
     GraphRunAbortedEvent,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
 )
@@ -260,12 +261,23 @@ class GraphEngine:
             if self._graph_execution.error:
                 raise self._graph_execution.error
             else:
-                yield GraphRunSucceededEvent(
-                    outputs=self._graph_runtime_state.outputs,
-                )
+                outputs = self._graph_runtime_state.outputs
+                exceptions_count = self._graph_execution.exceptions_count
+                if exceptions_count > 0:
+                    yield GraphRunPartialSucceededEvent(
+                        exceptions_count=exceptions_count,
+                        outputs=outputs,
+                    )
+                else:
+                    yield GraphRunSucceededEvent(
+                        outputs=outputs,
+                    )

         except Exception as e:
-            yield GraphRunFailedEvent(error=str(e))
+            yield GraphRunFailedEvent(
+                error=str(e),
+                exceptions_count=self._graph_execution.exceptions_count,
+            )
             raise

         finally:
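The engine now distinguishes three terminal outcomes: a stored error is re-raised, a clean run yields GraphRunSucceededEvent, and a run that recovered from node failures yields GraphRunPartialSucceededEvent carrying exceptions_count. A condensed sketch of the decision (strings stand in for the real event classes):

    def terminal_event(error: Exception | None, exceptions_count: int) -> str:
        if error is not None:
            raise error
        if exceptions_count > 0:
            # At least one node failed but was recovered by its error strategy.
            return "GraphRunPartialSucceededEvent"
        return "GraphRunSucceededEvent"


    assert terminal_event(None, 0) == "GraphRunSucceededEvent"
    assert terminal_event(None, 2) == "GraphRunPartialSucceededEvent"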
@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
     GraphEngineEvent,
     GraphRunAbortedEvent,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
     NodeRunExceptionEvent,
@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
             if self.include_outputs and event.outputs:
                 self.logger.info("  Final outputs: %s", self._format_dict(event.outputs))

+        elif isinstance(event, GraphRunPartialSucceededEvent):
+            self.logger.warning("⚠️ Graph run partially succeeded")
+            if event.exceptions_count > 0:
+                self.logger.warning("  Total exceptions: %s", event.exceptions_count)
+            if self.include_outputs and event.outputs:
+                self.logger.info("  Final outputs: %s", self._format_dict(event.outputs))
+
         elif isinstance(event, GraphRunFailedEvent):
             self.logger.error("❌ Graph run failed: %s", event.error)
             if event.exceptions_count > 0:
@@ -19,6 +19,7 @@ from core.workflow.enums import (
 from core.workflow.graph_events import (
     GraphNodeEventBase,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunSucceededEvent,
 )
 from core.workflow.node_events import (
@@ -456,7 +457,7 @@ class IterationNode(Node):
                 if isinstance(event, GraphNodeEventBase):
                     self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
                     yield event
-                elif isinstance(event, GraphRunSucceededEvent):
+                elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
                     result = variable_pool.get(self._node_data.output_selector)
                     if result is None:
                         outputs.append(None)
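Because partial success is a new terminal event, code that previously matched only GraphRunSucceededEvent has to accept it as well, which the tuple form of isinstance covers. A minimal illustration (stub classes, not the real event types):

    class GraphRunSucceededEvent: ...


    class GraphRunPartialSucceededEvent: ...


    def is_terminal_success(event: object) -> bool:
        # Partial success is handled the same way as full success here.
        return isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent))


    assert is_terminal_success(GraphRunSucceededEvent())
    assert is_terminal_success(GraphRunPartialSucceededEvent())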
@@ -0,0 +1,120 @@
+"""Tests for graph engine event handlers."""
+
+from __future__ import annotations
+
+from datetime import datetime
+
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
+from core.workflow.graph import Graph
+from core.workflow.graph_engine.domain.graph_execution import GraphExecution
+from core.workflow.graph_engine.event_management.event_handlers import EventHandler
+from core.workflow.graph_engine.event_management.event_manager import EventManager
+from core.workflow.graph_engine.graph_state_manager import GraphStateManager
+from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
+from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
+from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
+from core.workflow.node_events import NodeRunResult
+from core.workflow.nodes.base.entities import RetryConfig
+
+
+class _StubEdgeProcessor:
+    """Minimal edge processor stub for tests."""
+
+
+class _StubErrorHandler:
+    """Minimal error handler stub for tests."""
+
+
+class _StubNode:
+    """Simple node stub exposing the attributes needed by the state manager."""
+
+    def __init__(self, node_id: str) -> None:
+        self.id = node_id
+        self.state = NodeState.UNKNOWN
+        self.title = "Stub Node"
+        self.execution_type = NodeExecutionType.EXECUTABLE
+        self.error_strategy = None
+        self.retry_config = RetryConfig()
+        self.retry = False
+
+
+def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
+    """Construct an EventHandler with in-memory dependencies for testing."""
+
+    node = _StubNode(node_id)
+    graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
+
+    variable_pool = VariablePool()
+    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
+    graph_execution = GraphExecution(workflow_id="test-workflow")
+
+    event_manager = EventManager()
+    state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
+    response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
+
+    handler = EventHandler(
+        graph=graph,
+        graph_runtime_state=runtime_state,
+        graph_execution=graph_execution,
+        response_coordinator=response_coordinator,
+        event_collector=event_manager,
+        edge_processor=_StubEdgeProcessor(),
+        state_manager=state_manager,
+        error_handler=_StubErrorHandler(),
+    )
+
+    return handler, event_manager, graph_execution
+
+
+def test_retry_does_not_emit_additional_start_event() -> None:
+    """Ensure retry attempts do not produce duplicate start events."""
+
+    node_id = "test-node"
+    handler, event_manager, graph_execution = _build_event_handler(node_id)
+
+    execution_id = "exec-1"
+    node_type = NodeType.CODE
+    start_time = datetime.utcnow()
+
+    start_event = NodeRunStartedEvent(
+        id=execution_id,
+        node_id=node_id,
+        node_type=node_type,
+        node_title="Stub Node",
+        start_at=start_time,
+    )
+    handler.dispatch(start_event)
+
+    retry_event = NodeRunRetryEvent(
+        id=execution_id,
+        node_id=node_id,
+        node_type=node_type,
+        node_title="Stub Node",
+        start_at=start_time,
+        error="boom",
+        retry_index=1,
+        node_run_result=NodeRunResult(
+            status=WorkflowNodeExecutionStatus.FAILED,
+            error="boom",
+            error_type="TestError",
+        ),
+    )
+    handler.dispatch(retry_event)
+
+    # Simulate the node starting execution again after retry
+    second_start_event = NodeRunStartedEvent(
+        id=execution_id,
+        node_id=node_id,
+        node_type=node_type,
+        node_title="Stub Node",
+        start_at=start_time,
+    )
+    handler.dispatch(second_start_event)
+
+    collected_types = [type(event) for event in event_manager._events]  # type: ignore[attr-defined]
+
+    assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
+
+    node_execution = graph_execution.get_or_create_node_execution(node_id)
+    assert node_execution.retry_count == 1
@@ -10,11 +10,18 @@ import time
 from hypothesis import HealthCheck, given, settings
 from hypothesis import strategies as st

+from core.workflow.enums import ErrorStrategy
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
-from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
+from core.workflow.graph_events import (
+    GraphRunPartialSucceededEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+)
+from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType

 # Import the test framework from the new module
+from .test_mock_config import MockConfigBuilder
 from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase


@@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
     else:
         assert result.event_sequence_match is True
         assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
+
+
+def test_graph_run_emits_partial_success_when_node_failure_recovered():
+    runner = TableTestRunner()
+
+    fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
+    mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
+
+    graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
+        fixture_data=fixture_data,
+        query="hello",
+        use_mock_factory=True,
+        mock_config=mock_config,
+    )
+
+    llm_node = graph.nodes["llm"]
+    base_node_data = llm_node.get_base_node_data()
+    base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
+    base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
+
+    engine = GraphEngine(
+        workflow_id="test_workflow",
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        command_channel=InMemoryChannel(),
+    )
+
+    events = list(engine.run())
+
+    assert isinstance(events[-1], GraphRunPartialSucceededEvent)
+
+    partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
+    assert partial_event.exceptions_count == 1
+    assert partial_event.outputs.get("answer") == "fallback response"
+
+    assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
@@ -1,65 +0,0 @@
-import pytest
-
-pytest.skip(
-    "Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
-    allow_module_level=True,
-)
-
-DEFAULT_VALUE_EDGE = [
-    {
-        "id": "start-source-node-target",
-        "source": "start",
-        "target": "node",
-        "sourceHandle": "source",
-    },
-    {
-        "id": "node-source-answer-target",
-        "source": "node",
-        "target": "answer",
-        "sourceHandle": "source",
-    },
-]
-
-
-def test_retry_default_value_partial_success():
-    """retry default value node with partial success status"""
-    graph_config = {
-        "edges": DEFAULT_VALUE_EDGE,
-        "nodes": [
-            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
-            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
-            ContinueOnErrorTestHelper.get_http_node(
-                "default-value",
-                [{"key": "result", "type": "string", "value": "http node got error response"}],
-                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
-            ),
-        ],
-    }
-
-    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
-    events = list(graph_engine.run())
-    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
-    assert events[-1].outputs == {"answer": "http node got error response"}
-    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
-    assert len(events) == 11
-
-
-def test_retry_failed():
-    """retry failed with success status"""
-    graph_config = {
-        "edges": DEFAULT_VALUE_EDGE,
-        "nodes": [
-            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
-            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
-            ContinueOnErrorTestHelper.get_http_node(
-                None,
-                None,
-                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
-            ),
-        ],
-    }
-    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
-    events = list(graph_engine.run())
-    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
-    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
-    assert len(events) == 8