fix(graph_engine): error strategy fall. (#26078)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-09-23 01:51:43 +08:00 committed by GitHub
parent f4522fd695
commit 2e2c87c5a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 255 additions and 84 deletions

View File

@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False) completed: bool = Field(default=False)
aborted: bool = Field(default=False) aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None) error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list) exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@ -103,7 +104,8 @@ class GraphExecution:
completed: bool = False completed: bool = False
aborted: bool = False aborted: bool = False
error: Exception | None = None error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict) node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
def start(self) -> None: def start(self) -> None:
"""Mark the graph execution as started.""" """Mark the graph execution as started."""
@ -172,6 +174,7 @@ class GraphExecution:
completed=self.completed, completed=self.completed,
aborted=self.aborted, aborted=self.aborted,
error=_serialize_error(self.error), error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states, node_executions=node_states,
) )
@ -195,6 +198,7 @@ class GraphExecution:
self.completed = state.completed self.completed = state.completed
self.aborted = state.aborted self.aborted = state.aborted
self.error = _deserialize_error(state.error) self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = { self.node_executions = {
item.node_id: NodeExecution( item.node_id: NodeExecution(
node_id=item.node_id, node_id=item.node_id,
@ -205,3 +209,7 @@ class GraphExecution:
) )
for item in state.node_executions for item in state.node_executions
} }
def record_node_failure(self) -> None:
    """Bump the running tally of node failures for this graph execution by one."""
    self.exceptions_count = self.exceptions_count + 1

View File

@ -3,11 +3,12 @@ Event handler implementations for different event types.
""" """
import logging import logging
from collections.abc import Mapping
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import TYPE_CHECKING, final from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
@ -122,13 +123,15 @@ class EventHandler:
""" """
# Track execution in domain model # Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id) node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering # Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id) self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event # Collect the event only for the first attempt; retries remain silent
self._event_collector.collect(event) if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register @_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None: def _(self, event: NodeRunStreamChunkEvent) -> None:
@ -161,7 +164,7 @@ class EventHandler:
node_execution.mark_taken() node_execution.mark_taken()
# Store outputs in variable pool # Store outputs in variable pool
self._store_node_outputs(event) self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events # Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event) streaming_events = self._response_coordinator.intercept_event(event)
@ -191,7 +194,7 @@ class EventHandler:
# Handle response node outputs # Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE: if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event) self._update_response_outputs(event.node_run_result.outputs)
# Collect the event # Collect the event
self._event_collector.collect(event) self._event_collector.collect(event)
@ -207,6 +210,7 @@ class EventHandler:
# Update domain model # Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error) node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event) result = self._error_handler.handle_node_failure(event)
@ -227,10 +231,40 @@ class EventHandler:
Args: Args:
event: The node exception event event: The node exception event
""" """
# Node continues via fail-branch, so it's technically "succeeded" # Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken() node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register @_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None: def _(self, event: NodeRunRetryEvent) -> None:
""" """
@ -242,21 +276,31 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry() node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: # Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
    """
    Store node outputs in the variable pool.

    Each output is added under the selector ``(node_id, variable_name)``.

    Args:
        node_id: ID of the node that produced the outputs
        outputs: Mapping of output variable name to its value
    """
    for variable_name, variable_value in outputs.items():
        self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes.""" """Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state. # in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items(): for key, value in outputs.items():
if key == "answer": if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "") existing = self._graph_runtime_state.get_output("answer", "")
if existing: if existing:

View File

@ -23,6 +23,7 @@ from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
GraphRunAbortedEvent, GraphRunAbortedEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
) )
@ -260,12 +261,23 @@ class GraphEngine:
if self._graph_execution.error: if self._graph_execution.error:
raise self._graph_execution.error raise self._graph_execution.error
else: else:
yield GraphRunSucceededEvent( outputs = self._graph_runtime_state.outputs
outputs=self._graph_runtime_state.outputs, exceptions_count = self._graph_execution.exceptions_count
) if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)
except Exception as e: except Exception as e:
yield GraphRunFailedEvent(error=str(e)) yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
raise raise
finally: finally:

View File

@ -15,6 +15,7 @@ from core.workflow.graph_events import (
GraphEngineEvent, GraphEngineEvent,
GraphRunAbortedEvent, GraphRunAbortedEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
NodeRunExceptionEvent, NodeRunExceptionEvent,
@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs: if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error) self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0: if event.exceptions_count > 0:

View File

@ -19,6 +19,7 @@ from core.workflow.enums import (
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
) )
from core.workflow.node_events import ( from core.workflow.node_events import (
@ -456,7 +457,7 @@ class IterationNode(Node):
if isinstance(event, GraphNodeEventBase): if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index) self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event yield event
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector) result = variable_pool.get(self._node_data.output_selector)
if result is None: if result is None:
outputs.append(None) outputs.append(None)

View File

@ -0,0 +1,120 @@
"""Tests for graph engine event handlers."""
from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
class _StubEdgeProcessor:
    """Minimal edge processor stub for tests.

    Defines no methods of its own; it only fills the ``edge_processor``
    argument when an ``EventHandler`` is constructed in this module.
    """
class _StubErrorHandler:
    """Minimal error handler stub for tests.

    Defines no methods of its own; it only fills the ``error_handler``
    argument when an ``EventHandler`` is constructed in this module.
    """
class _StubNode:
    """Bare-bones stand-in for a graph node, carrying only plain attributes.

    The attribute set mirrors what the state manager is expected to read
    (per the original stub's description); no node behavior is implemented.
    """

    def __init__(self, node_id: str) -> None:
        # Identity and display attributes.
        self.id = node_id
        self.title = "Stub Node"

        # Execution bookkeeping: not yet run, ordinary executable node.
        self.state = NodeState.UNKNOWN
        self.execution_type = NodeExecutionType.EXECUTABLE

        # Error handling: no error strategy, default retry settings, retry flag off.
        self.error_strategy = None
        self.retry_config = RetryConfig()
        self.retry = False
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
    """Wire an EventHandler around a single stub node for use in tests.

    Args:
        node_id: ID given to the only node in the graph.

    Returns:
        The handler, the event manager used as its event collector, and the
        graph execution it updates, so callers can inspect both after dispatch.
    """
    stub_node = _StubNode(node_id)
    single_node_graph = Graph(
        nodes={node_id: stub_node},
        edges={},
        in_edges={},
        out_edges={},
        root_node=stub_node,
    )

    pool = VariablePool()
    state = GraphRuntimeState(variable_pool=pool, start_at=0.0)
    execution = GraphExecution(workflow_id="test-workflow")
    collector = EventManager()
    manager = GraphStateManager(graph=single_node_graph, ready_queue=InMemoryReadyQueue())
    coordinator = ResponseStreamCoordinator(variable_pool=pool, graph=single_node_graph)

    event_handler = EventHandler(
        graph=single_node_graph,
        graph_runtime_state=state,
        graph_execution=execution,
        response_coordinator=coordinator,
        event_collector=collector,
        edge_processor=_StubEdgeProcessor(),
        state_manager=manager,
        error_handler=_StubErrorHandler(),
    )
    return event_handler, collector, execution
def test_retry_does_not_emit_additional_start_event() -> None:
    """Ensure retry attempts do not produce duplicate start events.

    Dispatches start -> retry -> start for a single node and checks that only
    the first start event and the retry event were collected, and that the
    node's retry count ends at 1.
    """
    # Imported locally because the module-level import only brings in `datetime`.
    from datetime import timezone

    node_id = "test-node"
    handler, event_manager, graph_execution = _build_event_handler(node_id)

    execution_id = "exec-1"
    node_type = NodeType.CODE
    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive
    # UTC timestamp from a timezone-aware "now" instead.
    start_time = datetime.now(timezone.utc).replace(tzinfo=None)

    start_event = NodeRunStartedEvent(
        id=execution_id,
        node_id=node_id,
        node_type=node_type,
        node_title="Stub Node",
        start_at=start_time,
    )
    handler.dispatch(start_event)

    retry_event = NodeRunRetryEvent(
        id=execution_id,
        node_id=node_id,
        node_type=node_type,
        node_title="Stub Node",
        start_at=start_time,
        error="boom",
        retry_index=1,
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error="boom",
            error_type="TestError",
        ),
    )
    handler.dispatch(retry_event)

    # Simulate the node starting execution again after retry
    second_start_event = NodeRunStartedEvent(
        id=execution_id,
        node_id=node_id,
        node_type=node_type,
        node_title="Stub Node",
        start_at=start_time,
    )
    handler.dispatch(second_start_event)

    # The second start event must not be collected: only the initial start and
    # the retry event are expected, in that order.
    collected_types = [type(event) for event in event_manager._events]  # type: ignore[attr-defined]
    assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]

    node_execution = graph_execution.get_or_create_node_execution(node_id)
    assert node_execution.retry_count == 1

View File

@ -10,11 +10,18 @@ import time
from hypothesis import HealthCheck, given, settings from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st from hypothesis import strategies as st
from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
# Import the test framework from the new module # Import the test framework from the new module
from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
else: else:
assert result.event_sequence_match is True assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_graph_run_emits_partial_success_when_node_failure_recovered():
    """A node failure recovered via DEFAULT_VALUE ends the run as a partial success.

    The "llm" node of the ``basic_chatflow`` fixture is mocked to raise an error
    and is given a default ``text`` value; the run must finish with a
    GraphRunPartialSucceededEvent and never emit GraphRunSucceededEvent.
    """
    workflow_runner = TableTestRunner().workflow_runner
    fixture = workflow_runner.load_fixture("basic_chatflow")

    # Force the "llm" node to fail when it runs.
    failing_llm_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
    graph, graph_runtime_state = workflow_runner.create_graph_from_fixture(
        fixture_data=fixture,
        query="hello",
        use_mock_factory=True,
        mock_config=failing_llm_config,
    )

    # Switch the failing node to the default-value strategy with a fallback text.
    llm_data = graph.nodes["llm"].get_base_node_data()
    llm_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
    llm_data.default_value = [
        DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING),
    ]

    engine = GraphEngine(
        workflow_id="test_workflow",
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        command_channel=InMemoryChannel(),
    )
    events = list(engine.run())

    assert isinstance(events[-1], GraphRunPartialSucceededEvent)

    partial_event = next(evt for evt in events if isinstance(evt, GraphRunPartialSucceededEvent))
    assert partial_event.exceptions_count == 1
    assert partial_event.outputs.get("answer") == "fallback response"

    assert all(not isinstance(evt, GraphRunSucceededEvent) for evt in events)

View File

@ -1,65 +0,0 @@
import pytest
# Skip this entire module when it is imported/collected; the message records
# the reason given by the authors.
pytest.skip(
    "Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
    allow_module_level=True,
)

# Edge list shared by the tests below: a linear start -> node -> answer chain,
# with every edge leaving through the "source" handle.
DEFAULT_VALUE_EDGE = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-source-answer-target",
        "source": "node",
        "target": "answer",
        "sourceHandle": "source",
    },
]
def test_retry_default_value_partial_success():
    """Retry twice, then fall back to the default value and finish as a partial success.

    Expects two NodeRunRetryEvent instances, a GraphRunPartialSucceededEvent,
    the configured default text as the final ``answer`` output, and 11 events
    in total.
    """
    # NOTE(review): ContinueOnErrorTestHelper, NodeRunRetryEvent and
    # GraphRunPartialSucceededEvent are not imported in the visible part of this
    # module; the module-level skip prevents this from ever running - confirm.
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            # presumably the first argument selects the node's error strategy and
            # the second its default values - verify against the helper.
            ContinueOnErrorTestHelper.get_http_node(
                "default-value",
                [{"key": "result", "type": "string", "value": "http node got error response"}],
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }
    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    # One retry event per configured retry (max_retries=2).
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    # The answer node renders the default value supplied for "result".
    assert events[-1].outputs == {"answer": "http node got error response"}
    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
    assert len(events) == 11
def test_retry_failed():
    """Retry twice and, with no fallback configured, finish with a failed graph run.

    Expects two NodeRunRetryEvent instances, a GraphRunFailedEvent, and 8
    events in total.
    """
    # NOTE(review): ContinueOnErrorTestHelper, NodeRunRetryEvent and
    # GraphRunFailedEvent are not imported in the visible part of this module;
    # the module-level skip prevents this from ever running - confirm.
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            # presumably None/None means no error strategy and no default values -
            # verify against the helper.
            ContinueOnErrorTestHelper.get_http_node(
                None,
                None,
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }
    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    # One retry event per configured retry (max_retries=2).
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
    assert len(events) == 8