diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 3c3b04fde0c..a383839bd05 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -19,11 +19,16 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) +from core.app.entities.task_entities import WorkflowFinishStreamResponse, WorkflowStartStreamResponse from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus from graphon.runtime import GraphRuntimeState +from libs.datetime_utils import naive_utc_now from libs.flask_utils import set_login_user +from libs.helper import to_timestamp from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, Conversation, EndUser, Message @@ -173,14 +178,24 @@ class _AppRunner: ) except Exception as exc: if exec_params.streaming: - _publish_error_event(exc, exec_params.workflow_run_id, exec_params.app_mode) + _publish_failed_workflow_terminal_events( + exc=exc, + exec_params=exec_params, + ) raise if not exec_params.streaming: return response assert isinstance(response, Generator) - _publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode) + _publish_streaming_response( + response, + exec_params.workflow_run_id, + exec_params.app_mode, + exec_params.workflow_id, + exec_params.args.get("inputs", {}), + WorkflowStartReason.INITIAL, + ) def _run_app( self, @@ -246,29 +261,197 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun return session.get(EndUser, workflow_run.created_by) -def _publish_error_event(exc: Exception, workflow_run_id: str, app_mode: AppMode) -> None: - topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) - payload = json.dumps({"event": "error", "message": str(exc), "status": 500}) - topic.publish(payload.encode()) +def _publish_failed_workflow_terminal_events(exc: Exception, exec_params: AppExecutionParams) -> None: + """Publish synthetic workflow lifecycle events for pre-runtime failures. + + Early failures can happen before the app generator creates a task entity or + emits any workflow queue events. In that window SSE consumers still need a + normal terminal event to close their state machines, so we synthesize a + minimal `workflow_started -> workflow_finished(failed)` sequence here. + + `workflow_run_id` is reused as a synthetic `task_id` because no application + task id exists yet on this failure path. + """ + timestamp = to_timestamp(naive_utc_now()) + assert timestamp is not None + + topic = MessageBasedAppGenerator.get_response_topic(exec_params.app_mode, exec_params.workflow_run_id) + started_payload = WorkflowStartStreamResponse( + task_id=exec_params.workflow_run_id, + workflow_run_id=exec_params.workflow_run_id, + data=WorkflowStartStreamResponse.Data( + id=exec_params.workflow_run_id, + workflow_id=exec_params.workflow_id, + inputs=exec_params.args.get("inputs", {}), + created_at=timestamp, + reason=WorkflowStartReason.INITIAL, + ), + ) + topic.publish(json.dumps(started_payload.model_dump(mode="json"), ensure_ascii=False).encode()) + + finished_payload = WorkflowFinishStreamResponse( + task_id=exec_params.workflow_run_id, + workflow_run_id=exec_params.workflow_run_id, + data=WorkflowFinishStreamResponse.Data( + id=exec_params.workflow_run_id, + workflow_id=exec_params.workflow_id, + status=WorkflowExecutionStatus.FAILED, + outputs=None, + error=str(exc), + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by={}, + created_at=timestamp, + finished_at=timestamp, + exceptions_count=1, + files=[], + ), + ) + topic.publish(json.dumps(finished_payload.model_dump(mode="json"), ensure_ascii=False).encode()) + + +def _get_event_name(event: str | Mapping[str, Any] | BaseModel) -> str | None: + if isinstance(event, BaseModel): + # Temporary compatibility for legacy BaseModel stream events; remove after confirming generators always emit + # str / Mapping responses. + event_name = getattr(event, "event", None) + elif isinstance(event, Mapping): + event_name = event.get("event") + else: + return None + + if event_name is None: + return None + return str(event_name) + + +def _get_task_id(event: str | Mapping[str, Any] | BaseModel) -> str | None: + if isinstance(event, BaseModel): + # Temporary compatibility for legacy BaseModel stream events; remove after confirming generators always emit + # str / Mapping responses. + task_id = getattr(event, "task_id", None) + elif isinstance(event, Mapping): + task_id = event.get("task_id") + else: + return None + + return task_id if isinstance(task_id, str) and task_id else None def _publish_streaming_response( response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], - workflow_run_id: str, + workflow_run_id: str | uuid.UUID, app_mode: AppMode, + workflow_id: str, + inputs: Mapping[str, Any], + started_reason: WorkflowStartReason, ) -> None: - topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) - for event in response_stream: - try: - if isinstance(event, BaseModel): - payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) - else: - payload = json.dumps(event, ensure_ascii=False, default=str) - except (TypeError, ValueError): - logger.exception("error while encoding event") - continue + """Publish workflow stream events and close broken streams with a failed terminal event. - topic.publish(payload.encode()) + `_AppRunner.run()` only handles failures before the generator is returned. + Once we start iterating the runtime stream, this helper becomes the last + place that can guarantee SSE consumers eventually see a terminal workflow + lifecycle event. + """ + normalized_workflow_run_id = str(workflow_run_id) + + def _publish_failed_terminal_event(error_message: str, task_id: str, publish_started: bool) -> None: + timestamp = to_timestamp(naive_utc_now()) + assert timestamp is not None + + if publish_started: + started_payload = WorkflowStartStreamResponse( + task_id=task_id, + workflow_run_id=normalized_workflow_run_id, + data=WorkflowStartStreamResponse.Data( + id=normalized_workflow_run_id, + workflow_id=workflow_id, + inputs=inputs, + created_at=timestamp, + reason=started_reason, + ), + ) + topic.publish( + json.dumps( + started_payload.model_dump(mode="json", fallback=str), + ensure_ascii=False, + ).encode() + ) + + finished_payload = WorkflowFinishStreamResponse( + task_id=task_id, + workflow_run_id=normalized_workflow_run_id, + data=WorkflowFinishStreamResponse.Data( + id=normalized_workflow_run_id, + workflow_id=workflow_id, + status=WorkflowExecutionStatus.FAILED, + outputs=None, + error=error_message, + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by={}, + created_at=timestamp, + finished_at=timestamp, + exceptions_count=1, + files=[], + ), + ) + topic.publish(json.dumps(finished_payload.model_dump(mode="json"), ensure_ascii=False).encode()) + + terminal_events = {"workflow_finished", "workflow_paused"} + unexpected_stream_end_message = "Workflow stream ended without a terminal event" + topic = MessageBasedAppGenerator.get_response_topic(app_mode, normalized_workflow_run_id) + started_published = False + terminal_published = False + last_task_id = normalized_workflow_run_id + + try: + for event in response_stream: + event_name = _get_event_name(event) + task_id = _get_task_id(event) + if task_id is not None: + last_task_id = task_id + + try: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): + logger.exception("error while encoding event") + continue + + topic.publish(payload.encode()) + + if event_name == "workflow_started": + started_published = True + elif event_name in terminal_events: + terminal_published = True + except Exception as exc: + if not terminal_published: + logger.exception( + "Workflow stream for run %s failed before terminal event; publishing fallback terminal event", + normalized_workflow_run_id, + ) + _publish_failed_terminal_event( + error_message=str(exc) or exc.__class__.__name__, + task_id=last_task_id, + publish_started=not started_published, + ) + raise + + if not terminal_published: + logger.warning( + "Workflow stream for run %s ended without a terminal event; publishing fallback terminal event", + normalized_workflow_run_id, + ) + _publish_failed_terminal_event( + error_message=unexpected_stream_end_message, + task_id=last_task_id, + publish_started=not started_published, + ) @shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE) @@ -454,7 +637,14 @@ def _resume_advanced_chat( raise assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) + _publish_streaming_response( + response, + workflow_run_id, + AppMode.ADVANCED_CHAT, + workflow.id, + generate_entity.inputs, + WorkflowStartReason.RESUMPTION, + ) def _resume_workflow( @@ -509,7 +699,14 @@ def _resume_workflow( raise assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) + _publish_streaming_response( + response, + workflow_run_id, + AppMode.WORKFLOW, + workflow.id, + generate_entity.inputs, + WorkflowStartReason.RESUMPTION, + ) try: workflow_run_repo.delete_workflow_pause(pause_entity) diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py index 2544c9d61a9..40965096b39 100644 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -1,18 +1,26 @@ from __future__ import annotations import json +import logging import uuid +from contextlib import nullcontext from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from pydantic import BaseModel from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole from models.model import App, AppMode, Conversation from models.workflow import Workflow, WorkflowRun from repositories.sqlalchemy_api_workflow_run_repository import _WorkflowRunError +from tasks.app_generate import workflow_execute_task as workflow_execute_task_module from tasks.app_generate.workflow_execute_task import ( + AppExecutionParams, + _AppRunner, _publish_streaming_response, _resume_advanced_chat, _resume_app_execution, @@ -31,6 +39,11 @@ class _FakeSessionContext: return False +class _StreamEventModel(BaseModel): + event: object | None = None + task_id: object | None = None + + def _build_advanced_chat_generate_entity(conversation_id: str | None) -> AdvancedChatAppGenerateEntity: return AdvancedChatAppGenerateEntity( task_id="task-id", @@ -60,6 +73,46 @@ def _single_event_generator(payload): yield payload +def _decode_published_payload(payload: bytes) -> dict[str, object] | str: + return json.loads(payload.decode()) + + +def _published_payloads(topic: MagicMock) -> list[dict[str, object] | str]: + return [_decode_published_payload(call.args[0]) for call in topic.publish.call_args_list] + + +@pytest.mark.parametrize( + ("event", "expected"), + [ + ({"event": "workflow_started"}, "workflow_started"), + ({"event": 123}, "123"), + (_StreamEventModel(event="workflow_started"), "workflow_started"), + (_StreamEventModel(event=123), "123"), + ({}, None), + (_StreamEventModel(), None), + ("workflow_started", None), + ], +) +def test_get_event_name(event: object, expected: str | None): + assert workflow_execute_task_module._get_event_name(event) == expected + + +@pytest.mark.parametrize( + ("event", "expected"), + [ + ({"task_id": "task-id"}, "task-id"), + (_StreamEventModel(task_id="task-id"), "task-id"), + ({"task_id": 123}, None), + (_StreamEventModel(task_id=123), None), + ({"task_id": ""}, None), + (_StreamEventModel(), None), + ("task-id", None), + ], +) +def test_get_task_id(event: object, expected: str | None): + assert workflow_execute_task_module._get_task_id(event) == expected + + @pytest.fixture def mock_topic(monkeypatch: pytest.MonkeyPatch) -> MagicMock: topic = MagicMock() @@ -72,21 +125,413 @@ def mock_topic(monkeypatch: pytest.MonkeyPatch) -> MagicMock: def test_publish_streaming_response_with_uuid(mock_topic: MagicMock): workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "foo"}, "ping"]) + response_stream = iter( + [ + {"event": "workflow_started", "task_id": "task-id"}, + {"event": "workflow_finished", "task_id": "task-id", "data": {"status": "succeeded"}}, + ] + ) - _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT) + _publish_streaming_response( + response_stream, + workflow_run_id, + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) - payloads = [call.args[0] for call in mock_topic.publish.call_args_list] - assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()] + payloads = _published_payloads(mock_topic) + assert [payload["event"] for payload in payloads] == ["workflow_started", "workflow_finished"] def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "bar"}]) + response_stream = iter([{"event": "workflow_paused", "task_id": "task-id"}]) - _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) + _publish_streaming_response( + response_stream, + str(workflow_run_id), + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) - mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) + payloads = _published_payloads(mock_topic) + assert [payload["event"] for payload in payloads] == ["workflow_paused"] + + +def test_publish_streaming_response_publishes_started_then_failed_terminal_when_iteration_raises( + mock_topic: MagicMock, +): + def _response_stream(): + if False: + yield None + raise RuntimeError("stream exploded") + + with pytest.raises(RuntimeError, match="stream exploded"): + _publish_streaming_response( + _response_stream(), + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={"foo": "bar"}, + started_reason=WorkflowStartReason.INITIAL, + ) + + payloads = _published_payloads(mock_topic) + assert [payload["event"] for payload in payloads] == ["workflow_started", "workflow_finished"] + assert payloads[0]["data"]["workflow_id"] == "workflow-id" + assert payloads[0]["data"]["inputs"] == {"foo": "bar"} + assert payloads[1]["data"]["status"] == WorkflowExecutionStatus.FAILED + assert payloads[1]["data"]["error"] == "stream exploded" + + +def test_publish_streaming_response_recovers_when_workflow_started_publish_fails_first( + mock_topic: MagicMock, + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.ERROR, logger="tasks.app_generate.workflow_execute_task") + response_stream = iter([{"event": "workflow_started", "task_id": "task-id"}]) + successful_payloads: list[dict[str, object] | str] = [] + started_publish_attempts = 0 + + def _publish(payload: bytes) -> None: + nonlocal started_publish_attempts + + decoded = _decode_published_payload(payload) + if isinstance(decoded, dict) and decoded.get("event") == "workflow_started": + started_publish_attempts += 1 + if started_publish_attempts == 1: + raise RuntimeError("started publish failed") + successful_payloads.append(decoded) + + mock_topic.publish.side_effect = _publish + + with pytest.raises(RuntimeError, match="started publish failed"): + _publish_streaming_response( + response_stream, + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={"file": object()}, + started_reason=WorkflowStartReason.INITIAL, + ) + + assert [payload["event"] for payload in successful_payloads] == ["workflow_started", "workflow_finished"] + assert successful_payloads[0]["task_id"] == "task-id" + assert isinstance(successful_payloads[0]["data"]["inputs"]["file"], str) + assert successful_payloads[1]["task_id"] == "task-id" + assert successful_payloads[1]["data"]["status"] == WorkflowExecutionStatus.FAILED + assert successful_payloads[1]["data"]["error"] == "started publish failed" + assert "workflow-run-id" in caplog.text + assert "publishing fallback terminal event" in caplog.text + + +def test_publish_streaming_response_publishes_failed_terminal_without_duplicate_started_on_publish_error( + mock_topic: MagicMock, + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.ERROR, logger="tasks.app_generate.workflow_execute_task") + response_stream = iter( + [ + { + "event": "workflow_started", + "task_id": "task-id", + "workflow_run_id": "workflow-run-id", + "data": {"id": "workflow-run-id", "workflow_id": "workflow-id", "inputs": {}, "created_at": 1}, + }, + {"event": "node_started", "task_id": "task-id"}, + ] + ) + successful_payloads: list[dict[str, object] | str] = [] + + def _publish(payload: bytes) -> None: + decoded = _decode_published_payload(payload) + if isinstance(decoded, dict) and decoded.get("event") == "node_started": + raise RuntimeError("broker write failed") + successful_payloads.append(decoded) + + mock_topic.publish.side_effect = _publish + + with pytest.raises(RuntimeError, match="broker write failed"): + _publish_streaming_response( + response_stream, + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) + + assert [payload["event"] for payload in successful_payloads] == ["workflow_started", "workflow_finished"] + assert successful_payloads[1]["task_id"] == "task-id" + assert successful_payloads[1]["data"]["status"] == WorkflowExecutionStatus.FAILED + assert successful_payloads[1]["data"]["error"] == "broker write failed" + assert "workflow-run-id" in caplog.text + assert "publishing fallback terminal event" in caplog.text + + +def test_publish_streaming_response_recovers_when_workflow_finished_publish_fails_first( + mock_topic: MagicMock, + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.ERROR, logger="tasks.app_generate.workflow_execute_task") + response_stream = iter( + [ + {"event": "workflow_started", "task_id": "task-id"}, + {"event": "workflow_finished", "task_id": "task-id", "data": {"status": "succeeded"}}, + ] + ) + successful_payloads: list[dict[str, object] | str] = [] + finished_publish_attempts = 0 + + def _publish(payload: bytes) -> None: + nonlocal finished_publish_attempts + + decoded = _decode_published_payload(payload) + if isinstance(decoded, dict) and decoded.get("event") == "workflow_finished": + finished_publish_attempts += 1 + if finished_publish_attempts == 1: + raise RuntimeError("finished publish failed") + successful_payloads.append(decoded) + + mock_topic.publish.side_effect = _publish + + with pytest.raises(RuntimeError, match="finished publish failed"): + _publish_streaming_response( + response_stream, + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) + + assert [payload["event"] for payload in successful_payloads] == ["workflow_started", "workflow_finished"] + assert successful_payloads[1]["task_id"] == "task-id" + assert successful_payloads[1]["data"]["status"] == WorkflowExecutionStatus.FAILED + assert successful_payloads[1]["data"]["error"] == "finished publish failed" + assert "workflow-run-id" in caplog.text + assert "publishing fallback terminal event" in caplog.text + + +def test_publish_streaming_response_publishes_failed_terminal_on_exhaustion_without_terminal_event( + mock_topic: MagicMock, + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.WARNING, logger="tasks.app_generate.workflow_execute_task") + response_stream = iter( + [ + { + "event": "workflow_started", + "task_id": "task-id", + "workflow_run_id": "workflow-run-id", + "data": {"id": "workflow-run-id", "workflow_id": "workflow-id", "inputs": {}, "created_at": 1}, + } + ] + ) + + _publish_streaming_response( + response_stream, + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) + + payloads = _published_payloads(mock_topic) + assert [payload["event"] for payload in payloads] == ["workflow_started", "workflow_finished"] + assert payloads[1]["task_id"] == "task-id" + assert payloads[1]["data"]["status"] == WorkflowExecutionStatus.FAILED + assert payloads[1]["data"]["error"] == "Workflow stream ended without a terminal event" + assert "workflow-run-id" in caplog.text + assert "ended without a terminal event" in caplog.text + + +def test_publish_streaming_response_does_not_publish_synthetic_failure_after_terminal_event(mock_topic: MagicMock): + response_stream = iter( + [ + { + "event": "workflow_started", + "task_id": "task-id", + "workflow_run_id": "workflow-run-id", + "data": {"id": "workflow-run-id", "workflow_id": "workflow-id", "inputs": {}, "created_at": 1}, + }, + { + "event": "workflow_finished", + "task_id": "task-id", + "workflow_run_id": "workflow-run-id", + "data": { + "id": "workflow-run-id", + "workflow_id": "workflow-id", + "status": WorkflowExecutionStatus.SUCCEEDED, + "outputs": {}, + "error": None, + "elapsed_time": 0.1, + "total_tokens": 1, + "total_steps": 1, + "created_by": {}, + "created_at": 1, + "finished_at": 2, + "exceptions_count": 0, + "files": [], + }, + }, + ] + ) + + _publish_streaming_response( + response_stream, + "workflow-run-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + inputs={}, + started_reason=WorkflowStartReason.INITIAL, + ) + + payloads = _published_payloads(mock_topic) + assert [payload["event"] for payload in payloads] == ["workflow_started", "workflow_finished"] + + +def test_app_runner_streaming_failure_publishes_started_then_failed_workflow_finished( + mock_topic: MagicMock, monkeypatch +): + exec_params = AppExecutionParams( + app_id="app-id", + workflow_id="workflow-id", + tenant_id="tenant-id", + app_mode=AppMode.ADVANCED_CHAT, + user={"TYPE": "account", "user_id": "user-id"}, + args={"inputs": {}, "query": "test"}, + invoke_from=InvokeFrom.EXPLORE, + streaming=True, + workflow_run_id="workflow-run-id", + ) + runner = _AppRunner(session_factory=MagicMock(), exec_params=exec_params) + + workflow = SimpleNamespace(id="workflow-id", app_id="app-id", created_by="workflow-owner") + app = SimpleNamespace(id="app-id") + fake_session = MagicMock() + fake_session.get.side_effect = [workflow, app] + + monkeypatch.setattr(runner, "_session", lambda: nullcontext(fake_session)) + monkeypatch.setattr(runner, "_resolve_user", lambda: MagicMock()) + monkeypatch.setattr(runner, "_setup_flask_context", lambda _user: nullcontext()) + monkeypatch.setattr(runner, "_run_app", lambda **_kwargs: (_ for _ in ()).throw(ValueError("Invalid upload file"))) + + with pytest.raises(ValueError, match="Invalid upload file"): + runner.run() + + assert mock_topic.publish.call_count == 2 + started_payload = json.loads(mock_topic.publish.call_args_list[0].args[0].decode()) + assert started_payload["event"] == "workflow_started" + assert started_payload["workflow_run_id"] == "workflow-run-id" + assert started_payload["task_id"] == "workflow-run-id" + assert started_payload["data"]["id"] == "workflow-run-id" + assert started_payload["data"]["workflow_id"] == "workflow-id" + assert started_payload["data"]["reason"] == "initial" + + finished_payload = json.loads(mock_topic.publish.call_args_list[1].args[0].decode()) + assert finished_payload["event"] == "workflow_finished" + assert finished_payload["workflow_run_id"] == "workflow-run-id" + assert finished_payload["task_id"] == "workflow-run-id" + assert finished_payload["data"]["id"] == "workflow-run-id" + assert finished_payload["data"]["workflow_id"] == "workflow-id" + assert finished_payload["data"]["status"] == WorkflowExecutionStatus.FAILED + assert finished_payload["data"]["error"] == "Invalid upload file" + assert finished_payload["data"]["outputs"] is None + assert finished_payload["data"]["total_tokens"] == 0 + assert finished_payload["data"]["total_steps"] == 0 + assert finished_payload["data"]["exceptions_count"] == 1 + assert finished_payload["data"]["created_by"] == {} + assert finished_payload["data"]["created_at"] == finished_payload["data"]["finished_at"] + assert finished_payload["data"]["files"] == [] + + +def test_app_runner_streaming_failure_keeps_existing_pre_runtime_helper_behavior( + mock_topic: MagicMock, + monkeypatch: pytest.MonkeyPatch, +): + exec_params = AppExecutionParams( + app_id="app-id", + workflow_id="workflow-id", + tenant_id="tenant-id", + app_mode=AppMode.ADVANCED_CHAT, + user={"TYPE": "account", "user_id": "user-id"}, + args={"inputs": {}, "query": "test"}, + invoke_from=InvokeFrom.EXPLORE, + streaming=True, + workflow_run_id="workflow-run-id", + ) + runner = _AppRunner(session_factory=MagicMock(), exec_params=exec_params) + + workflow = SimpleNamespace(id="workflow-id", app_id="app-id", created_by="workflow-owner") + app = SimpleNamespace(id="app-id") + fake_session = MagicMock() + fake_session.get.side_effect = [workflow, app] + + monkeypatch.setattr(runner, "_session", lambda: nullcontext(fake_session)) + monkeypatch.setattr(runner, "_resolve_user", lambda: MagicMock()) + monkeypatch.setattr(runner, "_setup_flask_context", lambda _user: nullcontext()) + monkeypatch.setattr(runner, "_run_app", lambda **_kwargs: (_ for _ in ()).throw(ValueError("Invalid upload file"))) + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda value: (_ for _ in ()).throw(AssertionError("pre-runtime helper should not normalize inputs")), + ) + + with pytest.raises(ValueError, match="Invalid upload file"): + runner.run() + + payloads = _published_payloads(mock_topic) + assert payloads[0]["data"]["inputs"] == {} + assert payloads[0]["data"]["reason"] == WorkflowStartReason.INITIAL + + +def test_app_runner_streaming_success_calls_publish_streaming_response_with_full_signature( + monkeypatch: pytest.MonkeyPatch, +): + exec_params = AppExecutionParams( + app_id="app-id", + workflow_id="workflow-id", + tenant_id="tenant-id", + app_mode=AppMode.ADVANCED_CHAT, + user={"TYPE": "account", "user_id": "user-id"}, + args={"inputs": {"foo": "bar"}, "query": "test"}, + invoke_from=InvokeFrom.EXPLORE, + streaming=True, + workflow_run_id="workflow-run-id", + ) + runner = _AppRunner(session_factory=MagicMock(), exec_params=exec_params) + + workflow = SimpleNamespace(id="workflow-id", app_id="app-id", created_by="workflow-owner") + app = SimpleNamespace(id="app-id") + fake_session = MagicMock() + fake_session.get.side_effect = [workflow, app] + response_stream = _single_event_generator({"event": "message"}) + publish_streaming_response = MagicMock() + + monkeypatch.setattr(runner, "_session", lambda: nullcontext(fake_session)) + monkeypatch.setattr(runner, "_resolve_user", lambda: MagicMock()) + monkeypatch.setattr(runner, "_setup_flask_context", lambda _user: nullcontext()) + monkeypatch.setattr(runner, "_run_app", lambda **_kwargs: response_stream) + monkeypatch.setattr( + "tasks.app_generate.workflow_execute_task._publish_streaming_response", + publish_streaming_response, + ) + + runner.run() + + publish_streaming_response.assert_called_once_with( + response_stream, + exec_params.workflow_run_id, + exec_params.app_mode, + exec_params.workflow_id, + exec_params.args.get("inputs", {}), + WorkflowStartReason.INITIAL, + ) def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(monkeypatch: pytest.MonkeyPatch): @@ -247,6 +692,7 @@ def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversat def test_resume_advanced_chat_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch): generate_entity = _build_advanced_chat_generate_entity(conversation_id="conversation-id") generate_entity.stream = False + workflow = SimpleNamespace(id="workflow-id", created_by="workflow-owner") generator_instance = MagicMock() response_stream = _single_event_generator({"event": "message"}) @@ -271,7 +717,7 @@ def test_resume_advanced_chat_publishes_events_for_originally_blocking_runs(monk _resume_advanced_chat( app_model=SimpleNamespace(id="app-id"), - workflow=SimpleNamespace(created_by="workflow-owner"), + workflow=workflow, user=MagicMock(), conversation=SimpleNamespace(id="conversation-id"), message=MagicMock(), @@ -285,11 +731,19 @@ def test_resume_advanced_chat_publishes_events_for_originally_blocking_runs(monk resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"] assert resumed_entity.stream is True - publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.ADVANCED_CHAT) + publish_streaming_response.assert_called_once_with( + response_stream, + "workflow-run-id", + AppMode.ADVANCED_CHAT, + workflow.id, + generate_entity.inputs, + WorkflowStartReason.RESUMPTION, + ) def test_resume_workflow_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch): generate_entity = _build_workflow_generate_entity(stream=False) + workflow = SimpleNamespace(id="workflow-id", created_by="workflow-owner") generator_instance = MagicMock() response_stream = _single_event_generator({"event": "workflow_finished"}) @@ -316,7 +770,7 @@ def test_resume_workflow_publishes_events_for_originally_blocking_runs(monkeypat _resume_workflow( app_model=SimpleNamespace(id="app-id"), - workflow=SimpleNamespace(created_by="workflow-owner"), + workflow=workflow, user=MagicMock(), generate_entity=generate_entity, graph_runtime_state=MagicMock(), @@ -330,12 +784,20 @@ def test_resume_workflow_publishes_events_for_originally_blocking_runs(monkeypat resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"] assert resumed_entity.stream is True - publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.WORKFLOW) + publish_streaming_response.assert_called_once_with( + response_stream, + "workflow-run-id", + AppMode.WORKFLOW, + workflow.id, + generate_entity.inputs, + WorkflowStartReason.RESUMPTION, + ) workflow_run_repo.delete_workflow_pause.assert_called_once_with(pause_entity) def test_resume_workflow_ignores_missing_old_pause_after_repause(monkeypatch: pytest.MonkeyPatch): generate_entity = _build_workflow_generate_entity(stream=False) + workflow = SimpleNamespace(id="workflow-id", created_by="workflow-owner") generator_instance = MagicMock() response_stream = _single_event_generator({"event": "workflow_paused"}) @@ -363,7 +825,7 @@ def test_resume_workflow_ignores_missing_old_pause_after_repause(monkeypatch: py _resume_workflow( app_model=SimpleNamespace(id="app-id"), - workflow=SimpleNamespace(created_by="workflow-owner"), + workflow=workflow, user=MagicMock(), generate_entity=generate_entity, graph_runtime_state=MagicMock(), @@ -375,5 +837,12 @@ def test_resume_workflow_ignores_missing_old_pause_after_repause(monkeypatch: py pause_entity=pause_entity, ) - publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.WORKFLOW) + publish_streaming_response.assert_called_once_with( + response_stream, + "workflow-run-id", + AppMode.WORKFLOW, + workflow.id, + generate_entity.inputs, + WorkflowStartReason.RESUMPTION, + ) workflow_run_repo.delete_workflow_pause.assert_called_once_with(pause_entity) diff --git a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx index 10624f27574..cd0e4ea8c88 100644 --- a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx @@ -687,6 +687,29 @@ describe('useChat', () => { expect(lastResponse!.workflowProcess?.status).toBe('failed') }) + it('should store workflow finished error on workflow process state', async () => { + let callbacks: HookCallbacks + + vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { + callbacks = options as HookCallbacks + }) + + const { result } = renderHook(() => useChat()) + + act(() => { + result.current.handleSend('test-url', { query: 'failed workflow' }, {}) + }) + + act(() => { + callbacks.onWorkflowStarted({ workflow_run_id: 'wr-err', task_id: 't-err' }) + callbacks.onWorkflowFinished({ data: { status: 'failed', error: 'Invalid upload file' } }) + }) + + const lastResponse = result.current.chatList[1] + expect(lastResponse!.workflowProcess?.status).toBe('failed') + expect(lastResponse!.workflowProcess?.error).toBe('Invalid upload file') + }) + it('should insert and then replace child QA when sending with parent_message_id', () => { let callbacks: HookCallbacks vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { diff --git a/web/app/components/base/chat/chat/answer/__tests__/workflow-process.spec.tsx b/web/app/components/base/chat/chat/answer/__tests__/workflow-process.spec.tsx index 1e7d9f012ee..f8ae9aa3dc5 100644 --- a/web/app/components/base/chat/chat/answer/__tests__/workflow-process.spec.tsx +++ b/web/app/components/base/chat/chat/answer/__tests__/workflow-process.spec.tsx @@ -24,6 +24,21 @@ describe('WorkflowProcessItem', () => { expect(screen.queryByTestId('tracing-panel')).not.toBeInTheDocument() }) + it('should render workflow error message as collapsed title when failed without tracing', () => { + render( + , + ) + + expect(screen.getByTestId('workflow-process-title')).toHaveTextContent('Invalid upload file') + }) + it('should render "Workflow Process" title and TracingPanel when expanded', () => { // We expect t('common.workflowProcess', { ns: 'workflow' }) to be called render() @@ -31,6 +46,21 @@ describe('WorkflowProcessItem', () => { expect(screen.getByTestId('tracing-panel')).toBeInTheDocument() }) + it('should render workflow error message when failed without node tracing details', () => { + render( + , + ) + + expect(screen.getByText('Invalid upload file')).toBeInTheDocument() + }) + it('should toggle collapse state on header click', async () => { const user = userEvent.setup() render() @@ -89,7 +119,7 @@ describe('WorkflowProcessItem', () => { expect(screen.getByTestId('workflow-process-item')).toHaveClass('bg-workflow-process-paused-bg') rerender() - expect(screen.getByTestId('workflow-process-item')).toHaveClass('bg-workflow-process-failed-bg') + expect(screen.getByTestId('workflow-process-item')).toHaveClass('bg-[var(--color-workflow-process-failed-bg)]') }) it('should apply correct background when expanded for different statuses', () => { diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 3a751074e87..0427f9843b0 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -31,6 +31,10 @@ const WorkflowProcessItem = ({ const failed = data.status === WorkflowRunningStatus.Failed || data.status === WorkflowRunningStatus.Stopped const paused = data.status === WorkflowRunningStatus.Paused const latestNode = data.tracing[data.tracing.length - 1] + const fallbackTitle = t('common.workflowProcess', { ns: 'workflow' }) + const collapsedTitle = failed + ? data.error || latestNode?.error || latestNode?.title || fallbackTitle + : latestNode?.title || fallbackTitle useEffect(() => { setCollapse(!expand) @@ -50,7 +54,7 @@ const WorkflowProcessItem = ({ paused && !collapse && 'bg-state-warning-hover', collapse && !failed && !paused && 'bg-workflow-process-bg', collapse && paused && 'bg-workflow-process-paused-bg', - collapse && failed && 'bg-workflow-process-failed-bg', + collapse && failed && 'bg-[var(--color-workflow-process-failed-bg)]', )} data-testid="workflow-process-item" > @@ -92,21 +96,38 @@ const WorkflowProcessItem = ({ ) }
- {!collapse ? t('common.workflowProcess', { ns: 'workflow' }) : latestNode?.title} + {!collapse ? fallbackTitle : collapsedTitle}
{ !collapse && (
- + { + failed && data.error && ( +
+ {data.error} +
+ ) + } + { + data.tracing.length > 0 && ( + + ) + }
) } diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 982c408fb63..212c678b50f 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -405,7 +405,11 @@ export const useChat = ( hasStopRespondedRef.current = false updateChatTreeNode(messageId, (responseItem) => { if (responseItem.workflowProcess && responseItem.workflowProcess.tracing.length > 0) { - responseItem.workflowProcess.status = WorkflowRunningStatus.Running + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: WorkflowRunningStatus.Running, + error: undefined, + } } else { taskIdRef.current = task_id @@ -419,8 +423,13 @@ export const useChat = ( }, onWorkflowFinished: ({ data: workflowFinishedData }) => { updateChatTreeNode(messageId, (responseItem) => { - if (responseItem.workflowProcess) - responseItem.workflowProcess.status = workflowFinishedData.status as WorkflowRunningStatus + if (responseItem.workflowProcess) { + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: workflowFinishedData.status as WorkflowRunningStatus, + error: workflowFinishedData.error, + } + } }) }, onIterationStart: ({ data: iterationStartedData }) => { @@ -971,7 +980,11 @@ export const useChat = ( } if (responseItem.workflowProcess && responseItem.workflowProcess.tracing.length > 0) { - responseItem.workflowProcess.status = WorkflowRunningStatus.Running + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: WorkflowRunningStatus.Running, + error: undefined, + } } else { taskIdRef.current = task_id @@ -991,7 +1004,11 @@ export const useChat = ( onWorkflowFinished: ({ data: workflowFinishedData }) => { if (pausedStateRef.current) pausedStateRef.current = false - responseItem.workflowProcess!.status = workflowFinishedData.status as WorkflowRunningStatus + responseItem.workflowProcess = { + ...responseItem.workflowProcess!, + status: workflowFinishedData.status as WorkflowRunningStatus, + error: workflowFinishedData.error, + } updateCurrentQAOnTree({ placeholderQuestionId, questionItem, diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index 341dd3c6890..02e3113d679 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -38,6 +38,7 @@ export type ChatConfig = Omit & { export type WorkflowProcess = { status: WorkflowRunningStatus tracing: NodeTracing[] + error?: string expand?: boolean // for UI resultText?: string files?: FileEntity[] diff --git a/web/app/components/share/text-generation/result/__tests__/workflow-stream-handlers.spec.ts b/web/app/components/share/text-generation/result/__tests__/workflow-stream-handlers.spec.ts index 88d7769e1d0..7e9d9ea9ccf 100644 --- a/web/app/components/share/text-generation/result/__tests__/workflow-stream-handlers.spec.ts +++ b/web/app/components/share/text-generation/result/__tests__/workflow-stream-handlers.spec.ts @@ -635,6 +635,10 @@ describe('createWorkflowStreamHandlers', () => { message: 'failed', }) expect(failureSetup.onCompleted).toHaveBeenCalledWith('', 3, false) + expect(failureSetup.workflowProcessData()).toEqual(expect.objectContaining({ + status: WorkflowRunningStatus.Failed, + error: 'failed', + })) }) it('should cover existing workflow starts, stopped runs, and non-string outputs', () => { diff --git a/web/app/components/share/text-generation/result/workflow-stream-handlers.ts b/web/app/components/share/text-generation/result/workflow-stream-handlers.ts index 31b2fa283fd..cda0140c540 100644 --- a/web/app/components/share/text-generation/result/workflow-stream-handlers.ts +++ b/web/app/components/share/text-generation/result/workflow-stream-handlers.ts @@ -33,6 +33,7 @@ type CreateWorkflowStreamHandlersParams = { const createInitialWorkflowProcess = (): WorkflowProcess => ({ status: WorkflowRunningStatus.Running, tracing: [], + error: undefined, expand: false, resultText: '', }) @@ -148,9 +149,11 @@ const markNodesStopped = (traces?: WorkflowProcess['tracing']) => { const applyWorkflowFinishedState = ( current: WorkflowProcess | undefined, status: WorkflowRunningStatus, + error?: string, ) => { return updateWorkflowProcess(current, (draft) => { draft.status = status + draft.error = error if ([WorkflowRunningStatus.Stopped, WorkflowRunningStatus.Failed].includes(status)) markNodesStopped(draft.tracing) }) @@ -162,6 +165,7 @@ const applyWorkflowOutputs = ( ) => { return updateWorkflowProcess(current, (draft) => { draft.status = WorkflowRunningStatus.Succeeded + draft.error = undefined draft.files = getFilesInLogs(outputs || []) as unknown as WorkflowProcess['files'] }) } @@ -301,6 +305,7 @@ export const createWorkflowStreamHandlers = ({ setWorkflowProcessData(updateWorkflowProcess(workflowProcessData, (draft) => { draft.expand = true draft.status = WorkflowRunningStatus.Running + draft.error = undefined })) return } @@ -342,14 +347,18 @@ export const createWorkflowStreamHandlers = ({ const workflowStatus = data.status as WorkflowRunningStatus | undefined if (workflowStatus === WorkflowRunningStatus.Stopped) { - setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Stopped)) + setWorkflowProcessData( + applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Stopped, data.error), + ) finishWithFailure() return } if (data.error) { notify({ type: 'error', message: data.error }) - setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Failed)) + setWorkflowProcessData( + applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Failed, data.error), + ) finishWithFailure() return } diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts index e9fe31a909f..859ac821956 100644 --- a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts @@ -262,6 +262,27 @@ describe('useChat – handleResume', () => { const answer = result.current.chatList.find(item => item.id === 'msg-resume') expect(answer!.workflowProcess!.status).toBe('succeeded') }) + + it('should store workflow finished error on resume workflow process', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onWorkflowFinished({ + data: { status: 'failed', error: 'Invalid upload file' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('failed') + expect(answer!.workflowProcess!.error).toBe('Invalid upload file') + }) }) describe('onData', () => { diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts index 1b1c7659fd2..3e2087a7619 100644 --- a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts @@ -522,6 +522,19 @@ describe('useChat – handleSend SSE callbacks', () => { const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) expect(answer!.workflowProcess!.status).toBe('succeeded') }) + + it('should store workflow finished error on workflow process', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onWorkflowFinished({ data: { status: 'failed', error: 'Invalid upload file' } }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('failed') + expect(answer!.workflowProcess!.error).toBe('Invalid upload file') + }) }) describe('onIterationStart / onIterationFinish', () => { diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index d20a3311365..bde9e40ed44 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -61,9 +61,9 @@ export const useChat = ( ) => { const { t } = useTranslation() const { handleRun } = useWorkflowRun() - const hasStopResponded = useRef(false) + const hasStopRespondedRef = useRef(false) const workflowStore = useWorkflowStore() - const conversationId = useRef('') + const conversationIdRef = useRef('') const taskIdRef = useRef('') const [isResponding, setIsResponding] = useState(false) const isRespondingRef = useRef(false) @@ -72,7 +72,7 @@ export const useChat = ( const canRun = useHooksStore(s => s.accessControl.canRun) const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId) const { fetchInspectVars } = useSetWorkflowVarsWithValue() - const [suggestedQuestions, setSuggestQuestions] = useState([]) + const [suggestedQuestions, setSuggestedQuestions] = useState([]) const suggestedQuestionsAbortControllerRef = useRef(null) const { setIterTimes, @@ -190,7 +190,7 @@ export const useChat = ( }, [produceChatTreeNode]) const handleStop = useCallback(() => { - hasStopResponded.current = true + hasStopRespondedRef.current = true handleResponding(false) if (stopChat && taskIdRef.current) stopChat(taskIdRef.current) @@ -203,13 +203,13 @@ export const useChat = ( }, [handleResponding, setIterTimes, setLoopTimes, stopChat]) const handleRestart = useCallback(() => { - conversationId.current = '' + conversationIdRef.current = '' taskIdRef.current = '' handleStop() setIterTimes(DEFAULT_ITER_TIMES) setLoopTimes(DEFAULT_LOOP_TIMES) setChatTree([]) - setSuggestQuestions([]) + setSuggestedQuestions([]) }, [ handleStop, setIterTimes, @@ -353,7 +353,7 @@ export const useChat = ( } if (isFirstMessage && newConversationId) - conversationId.current = newConversationId + conversationIdRef.current = newConversationId taskIdRef.current = taskId if (messageId) @@ -403,17 +403,17 @@ export const useChat = ( return } - if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { + if (config?.suggested_questions_after_answer?.enabled && !hasStopRespondedRef.current && onGetSuggestedQuestions) { try { const { data }: any = await onGetSuggestedQuestions( responseItem.id, newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, ) - setSuggestQuestions(data) + setSuggestedQuestions(data) } // eslint-disable-next-line unused-imports/no-unused-vars catch (error) { - setSuggestQuestions([]) + setSuggestedQuestions([]) } } } @@ -439,7 +439,7 @@ export const useChat = ( onWorkflowStarted: ({ workflow_run_id, task_id, conversation_id, message_id }) => { // If there are no streaming messages, we still need to set the conversation_id to avoid create a new conversation when regeneration in chat-flow. if (conversation_id) { - conversationId.current = conversation_id + conversationIdRef.current = conversation_id } if (message_id && !hasSetResponseId) { questionItem.id = `question-${message_id}` @@ -450,7 +450,11 @@ export const useChat = ( if (responseItem.workflowProcess && responseItem.workflowProcess.tracing.length > 0) { handleResponding(true) - responseItem.workflowProcess.status = WorkflowRunningStatus.Running + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: WorkflowRunningStatus.Running, + error: undefined, + } } else { taskIdRef.current = task_id @@ -468,7 +472,11 @@ export const useChat = ( }) }, onWorkflowFinished: ({ data }) => { - responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus + responseItem.workflowProcess = { + ...responseItem.workflowProcess!, + status: data.status as WorkflowRunningStatus, + error: data.error, + } updateCurrentQAOnTree({ placeholderQuestionId, questionItem, @@ -723,7 +731,7 @@ export const useChat = ( }) if (newConversationId) - conversationId.current = newConversationId + conversationIdRef.current = newConversationId if (taskId) taskIdRef.current = taskId @@ -751,16 +759,16 @@ export const useChat = ( if (hasError) return - if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { + if (config?.suggested_questions_after_answer?.enabled && !hasStopRespondedRef.current && onGetSuggestedQuestions) { try { const { data }: any = await onGetSuggestedQuestions( messageId, newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, ) - setSuggestQuestions(data) + setSuggestedQuestions(data) } catch { - setSuggestQuestions([]) + setSuggestedQuestions([]) } } } @@ -782,10 +790,14 @@ export const useChat = ( }, onWorkflowStarted: ({ workflow_run_id, task_id }) => { handleResponding(true) - hasStopResponded.current = false + hasStopRespondedRef.current = false updateChatTreeNode(messageId, (responseItem) => { if (responseItem.workflowProcess && responseItem.workflowProcess.tracing.length > 0) { - responseItem.workflowProcess.status = WorkflowRunningStatus.Running + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: WorkflowRunningStatus.Running, + error: undefined, + } } else { taskIdRef.current = task_id @@ -799,8 +811,13 @@ export const useChat = ( }, onWorkflowFinished: ({ data: workflowFinishedData }) => { updateChatTreeNode(messageId, (responseItem) => { - if (responseItem.workflowProcess) - responseItem.workflowProcess.status = workflowFinishedData.status as WorkflowRunningStatus + if (responseItem.workflowProcess) { + responseItem.workflowProcess = { + ...responseItem.workflowProcess, + status: workflowFinishedData.status as WorkflowRunningStatus, + error: workflowFinishedData.error, + } + } }) }, onIterationStart: ({ data: iterationStartedData }) => { @@ -1004,7 +1021,7 @@ export const useChat = ( }, [handleResume]) return { - conversationId: conversationId.current, + conversationId: conversationIdRef.current, chatList, setTargetMessageId, handleSwitchSibling,