fix: preserve timing metrics in parallel iteration (#33216)

This commit is contained in:
盐粒 Yanli 2026-03-19 18:05:52 +08:00 committed by GitHub
parent 2b8823f38d
commit df0ded210f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 388 additions and 20 deletions

View File

@ -517,7 +517,7 @@ class WorkflowResponseConverter:
snapshot = self._pop_snapshot(event.node_execution_id) snapshot = self._pop_snapshot(event.node_execution_id)
start_at = snapshot.start_at if snapshot else event.start_at start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now() finished_at = event.finished_at or naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds() elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs) inputs, inputs_truncated = self._truncate_mapping(event.inputs)

View File

@ -456,6 +456,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type, node_type=event.node_type,
start_at=event.start_at, start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs=outputs, outputs=outputs,
@ -471,6 +472,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type, node_type=event.node_type,
start_at=event.start_at, start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs, inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data, process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs, outputs=event.node_run_result.outputs,
@ -487,6 +489,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type, node_type=event.node_type,
start_at=event.start_at, start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs, inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data, process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs, outputs=event.node_run_result.outputs,

View File

@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
start_at: datetime start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict)
@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
start_at: datetime start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict)
@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
start_at: datetime start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict)

View File

@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id) domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED) self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event.finished_at,
)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id) domain_execution = self._get_node_execution(event.id)
@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result, event.node_run_result,
WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.FAILED,
error=event.error, error=event.error,
finished_at=event.finished_at,
) )
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result, event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION, WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error, error=event.error,
finished_at=event.finished_at,
) )
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None: def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
*, *,
error: str | None = None, error: str | None = None,
update_outputs: bool = True, update_outputs: bool = True,
finished_at: datetime | None = None,
) -> None: ) -> None:
finished_at = naive_utc_now() actual_finished_at = finished_at or naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id) snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status domain_execution.status = status
domain_execution.finished_at = finished_at domain_execution.finished_at = actual_finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0) domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
if error: if error:
domain_execution.error = error domain_execution.error = error

View File

@ -159,6 +159,7 @@ class ErrorHandler:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type, node_type=event.node_type,
start_at=event.start_at, start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION, status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs, inputs=event.node_run_result.inputs,
@ -198,6 +199,7 @@ class ErrorHandler:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type, node_type=event.node_type,
start_at=event.start_at, start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION, status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs, inputs=event.node_run_result.inputs,

View File

@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override from typing_extensions import override
from dify_graph.context import IExecutionContext from dify_graph.context import IExecutionContext
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.node import Node
from libs.datetime_utils import naive_utc_now
from .ready_queue import ReadyQueue from .ready_queue import ReadyQueue
@ -65,6 +68,7 @@ class Worker(threading.Thread):
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._layers = layers if layers is not None else [] self._layers = layers if layers is not None else []
self._last_task_time = time.time() self._last_task_time = time.time()
self._current_node_started_at: datetime | None = None
def stop(self) -> None: def stop(self) -> None:
"""Signal the worker to stop processing.""" """Signal the worker to stop processing."""
@ -104,18 +108,15 @@ class Worker(threading.Thread):
self._last_task_time = time.time() self._last_task_time = time.time()
node = self._graph.nodes[node_id] node = self._graph.nodes[node_id]
try: try:
self._current_node_started_at = None
self._execute_node(node) self._execute_node(node)
self._ready_queue.task_done() self._ready_queue.task_done()
except Exception as e: except Exception as e:
error_event = NodeRunFailedEvent( self._event_queue.put(
id=node.execution_id, self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
) )
self._event_queue.put(error_event) finally:
self._current_node_started_at = None
def _execute_node(self, node: Node) -> None: def _execute_node(self, node: Node) -> None:
""" """
@ -136,6 +137,8 @@ class Worker(threading.Thread):
try: try:
node_events = node.run() node_events = node.run()
for event in node_events: for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event) self._event_queue.put(event)
if is_node_result_event(event): if is_node_result_event(event):
result_event = event result_event = event
@ -149,6 +152,8 @@ class Worker(threading.Thread):
try: try:
node_events = node.run() node_events = node.run()
for event in node_events: for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event) self._event_queue.put(event)
if is_node_result_event(event): if is_node_result_event(event):
result_event = event result_event = event
@ -177,3 +182,24 @@ class Worker(threading.Thread):
except Exception: except Exception:
# Silently ignore layer errors to prevent disrupting node execution # Silently ignore layer errors to prevent disrupting node execution
continue continue
def _build_fallback_failure_event(
    self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
    """Synthesize a NodeRunFailedEvent when worker-level execution aborts
    before the node could emit its own result event.

    ``started_at`` is the start time the worker observed from the node's own
    started event, if any; when absent, the failure time doubles as start_at.
    """
    now = naive_utc_now()
    message = str(error)
    # Carry a FAILED result so downstream consumers see status and error type.
    failed_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        error=message,
        error_type=type(error).__name__,
    )
    return NodeRunFailedEvent(
        id=node.execution_id,
        node_id=node.id,
        node_type=node.node_type,
        in_iteration_id=None,
        error=message,
        start_at=started_at or now,
        finished_at=now,
        node_run_result=failed_result,
    )

View File

@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
class NodeRunSucceededEvent(GraphNodeEventBase): class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time") start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunFailedEvent(GraphNodeEventBase): class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error") error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time") start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunExceptionEvent(GraphNodeEventBase): class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error") error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time") start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunRetryEvent(NodeRunStartedEvent): class NodeRunRetryEvent(NodeRunStartedEvent):

View File

@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
error=str(e), error=str(e),
error_type="WorkflowNodeError", error_type="WorkflowNodeError",
) )
finished_at = naive_utc_now()
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
id=self.execution_id, id=self.execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
start_at=self._start_at, start_at=self._start_at,
finished_at=finished_at,
node_run_result=result, node_run_result=result,
error=str(e), error=str(e),
) )
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
return self._node_data return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
finished_at = naive_utc_now()
match result.status: match result.status:
case WorkflowNodeExecutionStatus.FAILED: case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent( return NodeRunFailedEvent(
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id, node_id=self.id,
node_type=self.node_type, node_type=self.node_type,
start_at=self._start_at, start_at=self._start_at,
finished_at=finished_at,
node_run_result=result, node_run_result=result,
error=result.error, error=result.error,
) )
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id, node_id=self.id,
node_type=self.node_type, node_type=self.node_type,
start_at=self._start_at, start_at=self._start_at,
finished_at=finished_at,
node_run_result=result, node_run_result=result,
) )
case _: case _:
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register @_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
finished_at = naive_utc_now()
match event.node_run_result.status: match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED: case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent( return NodeRunSucceededEvent(
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
start_at=self._start_at, start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result, node_run_result=event.node_run_result,
) )
case WorkflowNodeExecutionStatus.FAILED: case WorkflowNodeExecutionStatus.FAILED:
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
start_at=self._start_at, start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result, node_run_result=event.node_run_result,
error=event.node_run_result.error, error=event.node_run_result.error,
) )

View File

@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
future_to_index: dict[ future_to_index: dict[
Future[ Future[
tuple[ tuple[
datetime, float,
list[GraphNodeEventBase], list[GraphNodeEventBase],
object | None, object | None,
dict[str, Variable], dict[str, Variable],
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
try: try:
result = future.result() result = future.result()
( (
iter_start_at, iteration_duration,
events, events,
output_value, output_value,
conversation_snapshot, conversation_snapshot,
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Yield all events from this iteration # Yield all events from this iteration
yield from events yield from events
# Update tokens and timing # The worker computes duration before we replay buffered events here,
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() # so slow downstream consumers don't inflate per-iteration timing.
iter_run_map[str(index)] = iteration_duration
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
index: int, index: int,
item: object, item: object,
execution_context: "IExecutionContext", execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results.""" """Execute a single iteration in parallel mode and return results."""
with execution_context: with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
conversation_snapshot = self._extract_conversation_variable_snapshot( conversation_snapshot = self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool variable_pool=graph_engine.graph_runtime_state.variable_pool
) )
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
return ( return (
iter_start_at, iteration_duration,
events, events,
output_value, output_value,
conversation_snapshot, conversation_snapshot,

View File

@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any from typing import Any
from unittest.mock import Mock from unittest.mock import Mock
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
assert response.data.process_data == {} assert response.data.process_data == {}
assert response.data.process_data_truncated is False assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_prefers_event_finished_at(
    self,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Finished timestamps should come from the event, not delayed queue processing time."""
    converter = self.create_workflow_response_converter()

    # Node ran from T+0s to T+2s, but the queue consumer only processes the
    # event at T+10s.
    start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
    delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
    # Pin the converter's wall clock to the delayed time so the test fails if
    # the converter falls back to "now" instead of using the event timestamp.
    monkeypatch.setattr(
        "core.app.apps.common.workflow_response_converter.naive_utc_now",
        lambda: delayed_processing_time,
    )
    # NOTE(review): presumably registers the workflow run so node-level
    # responses can be produced afterwards — confirm against the converter.
    converter.workflow_start_to_stream_response(
        task_id="bootstrap",
        workflow_run_id="run-id",
        workflow_id="wf-id",
        reason=WorkflowStartReason.INITIAL,
    )
    event = QueueNodeSucceededEvent(
        node_id="test-node-id",
        node_type=BuiltinNodeTypes.CODE,
        node_execution_id="node-exec-1",
        start_at=start_at,
        finished_at=finished_at,
        in_iteration_id=None,
        in_loop_id=None,
        inputs={},
        process_data={},
        outputs={},
        execution_metadata={},
    )
    response = converter.workflow_node_finish_to_stream_response(
        event=event,
        task_id="test-task-id",
    )
    assert response is not None
    # 2.0s = event.finished_at - event.start_at, not the 10s processing lag.
    assert response.data.elapsed_time == 2.0
    assert response.data.finished_at == int(finished_at.timestamp())
def test_workflow_node_retry_response_uses_truncated_process_data(self): def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data().""" """Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter() converter = self.create_workflow_response_converter()

View File

@ -0,0 +1,60 @@
from datetime import UTC, datetime
from unittest.mock import Mock
import pytest
from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
_NodeRuntimeSnapshot,
)
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
from dify_graph.node_events import NodeRunResult
def _build_layer() -> WorkflowPersistenceLayer:
    """Create a WorkflowPersistenceLayer wired to mock repositories for unit tests."""
    app_entity = Mock()
    app_entity.inputs = {}
    workflow_info = PersistenceWorkflowInfo(
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        version="1",
        graph_data={},
    )
    return WorkflowPersistenceLayer(
        application_generate_entity=app_entity,
        workflow_info=workflow_info,
        workflow_execution_repository=Mock(),
        workflow_node_execution_repository=Mock(),
    )
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
    """Persistence must record the event's finished_at, not the delayed write time."""
    layer = _build_layer()

    # Minimal stand-in for a domain node execution created at T+0s.
    node_execution = Mock()
    node_execution.id = "node-exec-1"
    node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    node_execution.update_from_mapping = Mock()
    # The snapshot supplies the start time the layer uses to compute elapsed_time.
    layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
        node_id="node-id",
        title="LLM",
        predecessor_node_id=None,
        iteration_id="iter-1",
        loop_id=None,
        created_at=node_execution.created_at,
    )

    # Event finished at T+2s; persistence only runs at T+10s.
    event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
    delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
    monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)

    layer._update_node_execution(
        node_execution,
        NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
        WorkflowNodeExecutionStatus.SUCCEEDED,
        finished_at=event_finished_at,
    )

    # elapsed_time reflects the event timing (2s), not the 10s persistence lag.
    assert node_execution.finished_at == event_finished_at
    assert node_execution.elapsed_time == 2.0

View File

@ -0,0 +1,145 @@
import queue
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
from dify_graph.graph_engine.worker import Worker
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
    """With no observed start time, the fallback event uses "now" for both timestamps."""
    fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    # Pin the worker's clock so the timestamp assertions are exact.
    mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)

    worker = Worker(
        ready_queue=InMemoryReadyQueue(),
        event_queue=queue.Queue(),
        graph=MagicMock(),
        layers=[],
    )
    # Duck-typed node: only the attributes the event builder reads.
    node = SimpleNamespace(
        execution_id="exec-1",
        id="node-1",
        node_type=BuiltinNodeTypes.LLM,
    )

    event = worker._build_fallback_failure_event(node, RuntimeError("boom"))

    # No started_at passed, so start and finish both fall back to "now".
    assert event.start_at == fixed_time
    assert event.finished_at == fixed_time
    assert event.error == "boom"
    # The synthesized result carries FAILED status plus the error details.
    assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert event.node_run_result.error == "boom"
    assert event.node_run_result.error_type == "RuntimeError"
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
    """If the node's started event was observed, the fallback failure keeps its start_at."""
    start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    failure_time = start_at + timedelta(seconds=5)
    captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []

    class FakeNode:
        # Matches the subset of the Node interface the worker touches.
        execution_id = "exec-1"
        id = "node-1"
        node_type = BuiltinNodeTypes.LLM

        def ensure_execution_id(self) -> str:
            return self.execution_id

        def run(self) -> Generator[NodeRunStartedEvent, None, None]:
            # Emit only the started event; the failure is injected at put time.
            yield NodeRunStartedEvent(
                id=self.execution_id,
                node_id=self.id,
                node_type=self.node_type,
                node_title="LLM",
                start_at=start_at,
            )

    worker = Worker(
        ready_queue=MagicMock(),
        event_queue=MagicMock(),
        graph=MagicMock(nodes={"node-1": FakeNode()}),
        layers=[],
    )
    # Hand the worker exactly one node to process.
    worker._ready_queue.get.side_effect = ["node-1"]

    def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
        # Fail the first put (the started event) to trigger the worker's
        # fallback path; any later successful put stops the worker.
        captured_events.append(event)
        if len(captured_events) == 1:
            raise RuntimeError("queue boom")
        worker.stop()

    worker._event_queue.put.side_effect = put_side_effect

    with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
        worker.run()

    fallback_event = captured_events[-1]
    assert isinstance(fallback_event, NodeRunFailedEvent)
    # start_at comes from the observed started event, not the failure clock.
    assert fallback_event.start_at == start_at
    assert fallback_event.finished_at == failure_time
    assert fallback_event.error == "queue boom"
    assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
    """Only the parent node's own started event may set the fallback start time."""
    parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
    child_start = parent_start + timedelta(seconds=3)
    failure_time = parent_start + timedelta(seconds=5)
    captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []

    class FakeIterationNode:
        # Minimal iteration-node stand-in for the worker's dispatch.
        execution_id = "iteration-exec"
        id = "iteration-node"
        node_type = BuiltinNodeTypes.ITERATION

        def ensure_execution_id(self) -> str:
            return self.execution_id

        def run(self) -> Generator[NodeRunStartedEvent, None, None]:
            # The parent's own started event (id matches execution_id)...
            yield NodeRunStartedEvent(
                id=self.execution_id,
                node_id=self.id,
                node_type=self.node_type,
                node_title="Iteration",
                start_at=parent_start,
            )
            # ...then a nested child's started event, which must NOT
            # overwrite the tracked parent start time.
            yield NodeRunStartedEvent(
                id="child-exec",
                node_id="child-node",
                node_type=BuiltinNodeTypes.LLM,
                node_title="LLM",
                start_at=child_start,
                in_iteration_id=self.id,
            )

    worker = Worker(
        ready_queue=MagicMock(),
        event_queue=MagicMock(),
        graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
        layers=[],
    )
    worker._ready_queue.get.side_effect = ["iteration-node"]

    def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
        # Fail on the child's started event (second put) to trigger the
        # fallback; any successful put stops the worker.
        captured_events.append(event)
        if len(captured_events) == 2:
            raise RuntimeError("queue boom")
        worker.stop()

    worker._event_queue.put.side_effect = put_side_effect

    with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
        worker.run()

    fallback_event = captured_events[-1]
    assert isinstance(fallback_event, NodeRunFailedEvent)
    # The fallback keeps the parent's start time, ignoring the child's later one.
    assert fallback_event.start_at == parent_start
    assert fallback_event.finished_at == failure_time

View File

@ -0,0 +1,63 @@
import time
from contextlib import nullcontext
from datetime import UTC, datetime
import pytest
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.iteration_node import IterationNode
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
    """Per-iteration timing must be the worker-measured duration, not consumer replay time."""
    # Construct without __init__ to skip graph/runtime wiring the test doesn't need.
    node = IterationNode.__new__(IterationNode)
    node._node_data = IterationNodeData(
        title="Parallel Iteration",
        iterator_selector=["start", "items"],
        output_selector=["iteration", "output"],
        is_parallel=True,
        parallel_nums=2,
        error_handle_mode=ErrorHandleMode.TERMINATED,
    )
    # Stub out collaborators irrelevant to timing.
    node._capture_execution_context = lambda: nullcontext()
    node._sync_conversation_variables_from_snapshot = lambda snapshot: None
    node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)

    def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
        # Report a fixed worker-measured duration: 0.1s for index 0, 0.2s for index 1.
        return (
            0.1 + (index * 0.1),
            [
                NodeRunSucceededEvent(
                    id=f"exec-{index}",
                    node_id=f"llm-{index}",
                    node_type=BuiltinNodeTypes.LLM,
                    start_at=datetime.now(UTC).replace(tzinfo=None),
                ),
            ],
            f"output-{item}",
            {},
            LLMUsage.empty_usage(),
        )

    node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel

    outputs: list[object] = []
    iter_run_map: dict[str, float] = {}
    usage_accumulator = [LLMUsage.empty_usage()]

    generator = node._execute_parallel_iterations(
        iterator_list_value=["a", "b"],
        outputs=outputs,
        iter_run_map=iter_run_map,
        usage_accumulator=usage_accumulator,
    )
    for _ in generator:
        # Simulate a slow consumer replaying buffered events.
        time.sleep(0.02)

    assert outputs == ["output-a", "output-b"]
    # Durations are exactly what the workers reported, unaffected by the sleeps.
    assert iter_run_map["0"] == pytest.approx(0.1)
    assert iter_run_map["1"] == pytest.approx(0.2)