diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 65b8af67065..b30603b860d 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -7,7 +7,13 @@ from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunSucceededEvent, +) from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity @@ -23,6 +29,7 @@ class TriggerPostLayer(GraphEngineLayer): _STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = { GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED, GraphRunFailedEvent: WorkflowTriggerStatus.FAILED, + GraphRunAbortedEvent: WorkflowTriggerStatus.FAILED, GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED, } @@ -73,6 +80,8 @@ class TriggerPostLayer(GraphEngineLayer): trigger_log.status = self._STATUS_MAP[type(event)] trigger_log.workflow_run_id = workflow_run_id trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode() + if isinstance(event, GraphRunAbortedEvent): + trigger_log.error = event.reason or "Workflow execution aborted" if trigger_log.elapsed_time is None: trigger_log.elapsed_time = elapsed_time diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index 320a3bc42cc..f82cf201422 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -4,7 +4,11 @@ from unittest.mock import Mock, patch from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables -from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.graph_events import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunSucceededEvent, +) from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus @@ -59,6 +63,57 @@ class TestTriggerPostLayer: repo.update.assert_called_once_with(trigger_log) session.commit.assert_called_once() + def test_on_event_updates_trigger_log_for_aborted_event(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + error=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"partial": "ok"}, + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), + total_tokens=7, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunAbortedEvent(reason="timeout")) + + assert trigger_log.status == WorkflowTriggerStatus.FAILED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.error == "timeout" + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 7 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + def test_on_event_handles_missing_trigger_log(self): runtime_state = SimpleNamespace( outputs={},