mirror of https://github.com/langgenius/dify.git
fix(graph_engine): error strategy fall. (#26078)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f4522fd695
commit
2e2c87c5a1
|
|
@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
|
|||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
|
|
@ -103,7 +104,8 @@ class GraphExecution:
|
|||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
|
|
@ -172,6 +174,7 @@ class GraphExecution:
|
|||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
|
|
@ -195,6 +198,7 @@ class GraphExecution:
|
|||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
|
|
@ -205,3 +209,7 @@ class GraphExecution:
|
|||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ Event handler implementations for different event types.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
|
|
@ -122,13 +123,15 @@ class EventHandler:
|
|||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
|
|
@ -161,7 +164,7 @@ class EventHandler:
|
|||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
|
|
@ -191,7 +194,7 @@ class EventHandler:
|
|||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
|
@ -207,6 +210,7 @@ class EventHandler:
|
|||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
|
|
@ -227,10 +231,40 @@ class EventHandler:
|
|||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
# Node continues via fail-branch/default-value, treat as completion
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
|
||||
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update response outputs if applicable
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Collect the exception event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
|
|
@ -242,21 +276,31 @@ class EventHandler:
|
|||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
# Finish the previous attempt before re-queuing the node
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Emit retry event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
# Re-queue node for execution
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
for key, value in outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from core.workflow.graph_events import (
|
|||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
|
@ -260,12 +261,23 @@ class GraphEngine:
|
|||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
yield GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.workflow.graph_events import (
|
|||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
|
|
@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
|
|||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.workflow.enums import (
|
|||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
|
|
@ -456,7 +457,7 @@ class IterationNode(Node):
|
|||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for graph engine event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
|
||||
from core.workflow.graph_engine.event_management.event_manager import EventManager
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
|
||||
|
||||
class _StubEdgeProcessor:
|
||||
"""Minimal edge processor stub for tests."""
|
||||
|
||||
|
||||
class _StubErrorHandler:
|
||||
"""Minimal error handler stub for tests."""
|
||||
|
||||
|
||||
class _StubNode:
|
||||
"""Simple node stub exposing the attributes needed by the state manager."""
|
||||
|
||||
def __init__(self, node_id: str) -> None:
|
||||
self.id = node_id
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = "Stub Node"
|
||||
self.execution_type = NodeExecutionType.EXECUTABLE
|
||||
self.error_strategy = None
|
||||
self.retry_config = RetryConfig()
|
||||
self.retry = False
|
||||
|
||||
|
||||
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
|
||||
"""Construct an EventHandler with in-memory dependencies for testing."""
|
||||
|
||||
node = _StubNode(node_id)
|
||||
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_execution = GraphExecution(workflow_id="test-workflow")
|
||||
|
||||
event_manager = EventManager()
|
||||
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
|
||||
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
|
||||
|
||||
handler = EventHandler(
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
graph_execution=graph_execution,
|
||||
response_coordinator=response_coordinator,
|
||||
event_collector=event_manager,
|
||||
edge_processor=_StubEdgeProcessor(),
|
||||
state_manager=state_manager,
|
||||
error_handler=_StubErrorHandler(),
|
||||
)
|
||||
|
||||
return handler, event_manager, graph_execution
|
||||
|
||||
|
||||
def test_retry_does_not_emit_additional_start_event() -> None:
|
||||
"""Ensure retry attempts do not produce duplicate start events."""
|
||||
|
||||
node_id = "test-node"
|
||||
handler, event_manager, graph_execution = _build_event_handler(node_id)
|
||||
|
||||
execution_id = "exec-1"
|
||||
node_type = NodeType.CODE
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(start_event)
|
||||
|
||||
retry_event = NodeRunRetryEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
error="boom",
|
||||
retry_index=1,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="boom",
|
||||
error_type="TestError",
|
||||
),
|
||||
)
|
||||
handler.dispatch(retry_event)
|
||||
|
||||
# Simulate the node starting execution again after retry
|
||||
second_start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(second_start_event)
|
||||
|
||||
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
|
||||
|
||||
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
|
||||
|
||||
node_execution = graph_execution.get_or_create_node_execution(node_id)
|
||||
assert node_execution.retry_count == 1
|
||||
|
|
@ -10,11 +10,18 @@ import time
|
|||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
|
||||
|
||||
# Import the test framework from the new module
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
||||
|
||||
|
||||
|
|
@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
|
|||
else:
|
||||
assert result.event_sequence_match is True
|
||||
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
|
||||
|
||||
|
||||
def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
||||
runner = TableTestRunner()
|
||||
|
||||
fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
|
||||
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
|
||||
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
query="hello",
|
||||
use_mock_factory=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
llm_node = graph.nodes["llm"]
|
||||
base_node_data = llm_node.get_base_node_data()
|
||||
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
|
||||
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
assert isinstance(events[-1], GraphRunPartialSucceededEvent)
|
||||
|
||||
partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
|
||||
assert partial_event.exceptions_count == 1
|
||||
assert partial_event.outputs.get("answer") == "fallback response"
|
||||
|
||||
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
import pytest
|
||||
|
||||
pytest.skip(
|
||||
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value",
|
||||
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
None,
|
||||
None,
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
Loading…
Reference in New Issue