diff --git a/api/controllers/service_api/app/workflow_events.py b/api/controllers/service_api/app/workflow_events.py
index 58bbbbbd1f..51a977423c 100644
--- a/api/controllers/service_api/app/workflow_events.py
+++ b/api/controllers/service_api/app/workflow_events.py
@@ -36,6 +36,7 @@ class WorkflowEventsApi(Resource):
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": "Whether to replay from persisted state snapshot",
+ "continue_on_pause": "Whether to keep the stream open across workflow_paused events",
}
)
@service_api_ns.doc(
@@ -97,6 +98,8 @@ class WorkflowEventsApi(Resource):
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
+ continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
+ terminal_events = ["workflow_finished"] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
@@ -107,10 +110,15 @@ class WorkflowEventsApi(Resource):
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
+ close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
- msg_generator.retrieve_events(app_mode, workflow_run_entity.id),
+ msg_generator.retrieve_events(
+ app_mode,
+ workflow_run_entity.id,
+ terminal_events=terminal_events,
+ ),
)
event_generator = _generate_stream_events
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 985ded0f74..2ec0fe32f0 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -39,7 +39,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
-from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.app.entities.task_entities import (
+ ChatbotAppBlockingResponse,
+ ChatbotAppPausedBlockingResponse,
+ ChatbotAppStreamResponse,
+)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@@ -656,7 +660,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
- ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
+ ) -> (
+ ChatbotAppBlockingResponse
+ | ChatbotAppPausedBlockingResponse
+ | Generator[ChatbotAppStreamResponse, None, None]
+ ):
"""
Handle response.
:param application_generate_entity: application generate entity
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index fe2702ed69..15b19b2db9 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -3,9 +3,9 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
- AppBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
+ ChatbotAppPausedBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
@@ -14,17 +14,35 @@ from core.app.entities.task_entities import (
PingStreamResponse,
)
-
-class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = ChatbotAppBlockingResponse
+class AdvancedChatAppGenerateResponseConverter(
+ AppGenerateResponseConverter[ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse]
+):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+ def convert_blocking_full_response(
+ cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
+ ) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
+ if isinstance(blocking_response, ChatbotAppPausedBlockingResponse):
+ paused_data = blocking_response.data.model_dump(mode="json")
+ return {
+ "event": "workflow_paused",
+ "task_id": blocking_response.task_id,
+ "id": blocking_response.data.id,
+ "message_id": blocking_response.data.message_id,
+ "conversation_id": blocking_response.data.conversation_id,
+ "mode": blocking_response.data.mode,
+ "answer": blocking_response.data.answer,
+ "metadata": blocking_response.data.metadata,
+ "created_at": blocking_response.data.created_at,
+ "workflow_run_id": blocking_response.data.workflow_run_id,
+ "data": paused_data,
+ }
+
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
"event": "message",
@@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+ def convert_blocking_simple_response(
+ cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
+ ) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
- response["metadata"] = cls._get_simple_metadata(metadata)
+ if isinstance(metadata, dict):
+ response["metadata"] = cls._get_simple_metadata(metadata)
return response
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 0ce9ddce9e..bfe1ee789c 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -9,7 +9,7 @@ from datetime import datetime
from threading import Thread
from typing import Any, Union
-from graphon.entities.pause_reason import HumanInputRequired
+from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.utils.encoders import jsonable_encoder
@@ -60,14 +60,17 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
+ ChatbotAppPausedBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
+ HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
+    WorkflowPauseStreamResponse,
     WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
@@ -210,7 +213,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
- def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+ def process(
+ self,
+ ) -> Union[
+ ChatbotAppBlockingResponse,
+ ChatbotAppPausedBlockingResponse,
+ Generator[ChatbotAppStreamResponse, None, None],
+ ]:
"""
Process generate task pipeline.
:return:
@@ -226,14 +235,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
- def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
+ def _to_blocking_response(
+ self, generator: Generator[StreamResponse, None, None]
+ ) -> Union[ChatbotAppBlockingResponse, ChatbotAppPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
+ human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
+ elif isinstance(stream_response, HumanInputRequiredResponse):
+ human_input_responses.append(stream_response)
+ elif isinstance(stream_response, WorkflowPauseStreamResponse):
+ return ChatbotAppPausedBlockingResponse(
+ task_id=stream_response.task_id,
+ data=ChatbotAppPausedBlockingResponse.Data(
+ id=self._message_id,
+ mode=self._conversation_mode,
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
+ workflow_run_id=stream_response.data.workflow_run_id,
+ answer=self._task_state.answer,
+ metadata=self._message_end_to_stream_response().metadata,
+ created_at=self._message_created_at,
+ paused_nodes=stream_response.data.paused_nodes,
+ reasons=stream_response.data.reasons,
+ status=stream_response.data.status,
+ elapsed_time=stream_response.data.elapsed_time,
+ total_tokens=stream_response.data.total_tokens,
+ total_steps=stream_response.data.total_steps,
+ ),
+ )
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@@ -254,8 +288,42 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
+ if human_input_responses:
+ return self._build_paused_blocking_response_from_human_input(human_input_responses)
+
raise ValueError("queue listening stopped unexpectedly.")
+ def _build_paused_blocking_response_from_human_input(
+ self, human_input_responses: list[HumanInputRequiredResponse]
+ ) -> ChatbotAppPausedBlockingResponse:
+ runtime_state = self._resolve_graph_runtime_state()
+ paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
+ reasons = []
+ for response in human_input_responses:
+ reason = response.data.model_dump(mode="json")
+ reason["type"] = PauseReasonType.HUMAN_INPUT_REQUIRED
+ reasons.append(reason)
+
+ return ChatbotAppPausedBlockingResponse(
+ task_id=self._application_generate_entity.task_id,
+ data=ChatbotAppPausedBlockingResponse.Data(
+ id=self._message_id,
+ mode=self._conversation_mode,
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
+ workflow_run_id=human_input_responses[-1].workflow_run_id,
+ answer=self._task_state.answer,
+ metadata=self._message_end_to_stream_response().metadata,
+ created_at=self._message_created_at,
+ paused_nodes=paused_nodes,
+ reasons=reasons,
+ status=WorkflowExecutionStatus.PAUSED,
+ elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
+ total_tokens=runtime_state.total_tokens,
+ total_steps=runtime_state.node_run_steps,
+ ),
+ )
+
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py
index 731c6ee12e..15984aff03 100644
--- a/api/core/app/apps/agent_chat/generate_response_converter.py
+++ b/api/core/app/apps/agent_chat/generate_response_converter.py
@@ -12,11 +12,10 @@ from core.app.entities.task_entities import (
)
-class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = ChatbotAppBlockingResponse
+class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +36,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py
index 406d07927e..b04f811b46 100644
--- a/api/core/app/apps/base_app_generate_response_converter.py
+++ b/api/core/app/apps/base_app_generate_response_converter.py
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import Any, Generic, TypeVar, Union, cast
from graphon.model_runtime.errors.invoke import InvokeError
@@ -12,8 +12,13 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
logger = logging.getLogger(__name__)
-class AppGenerateResponseConverter(ABC):
- _blocking_response_type: type[AppBlockingResponse]
+TBlockingResponse = TypeVar("TBlockingResponse", bound=AppBlockingResponse)
+
+
+class AppGenerateResponseConverter(Generic[TBlockingResponse], ABC):
+ @classmethod
+ def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
+ return cast(TBlockingResponse, response)
@classmethod
def convert(
@@ -21,7 +26,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
- return cls.convert_blocking_full_response(response)
+ return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -30,7 +35,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
- return cls.convert_blocking_simple_response(response)
+ return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -40,12 +45,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
- def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+ def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
- def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+ def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py
index 3d0375151d..7c19981a82 100644
--- a/api/core/app/apps/chat/generate_response_converter.py
+++ b/api/core/app/apps/chat/generate_response_converter.py
@@ -12,11 +12,10 @@ from core.app.entities.task_entities import (
)
-class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = ChatbotAppBlockingResponse
+class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +36,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
diff --git a/api/core/app/apps/common/pause_reason_serializer.py b/api/core/app/apps/common/pause_reason_serializer.py
new file mode 100644
index 0000000000..ef9ce5d05b
--- /dev/null
+++ b/api/core/app/apps/common/pause_reason_serializer.py
@@ -0,0 +1,17 @@
+from collections.abc import Mapping
+from typing import Any
+
+from graphon.entities.pause_reason import PauseReason
+
+
+def pause_reason_to_public_dict(reason: PauseReason | Mapping[str, Any]) -> dict[str, Any]:
+ if isinstance(reason, Mapping):
+ data = dict(reason)
+ else:
+ data = dict(reason.model_dump(mode="json"))
+
+ discriminator = data.pop("TYPE", None)
+ if discriminator is not None:
+ data["type"] = discriminator
+
+ return data
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index a515531616..47f7b6c1e2 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -22,6 +22,7 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from sqlalchemy import select
from sqlalchemy.orm import Session
+from core.app.apps.common.pause_reason_serializer import pause_reason_to_public_dict
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
@@ -317,7 +318,7 @@ class WorkflowResponseConverter:
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
- pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
+ pause_reasons = [pause_reason_to_public_dict(reason) for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
@@ -338,6 +339,21 @@ class WorkflowResponseConverter:
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
+ for pause_reason in pause_reasons:
+ if pause_reason.get("type") != "human_input_required":
+ continue
+
+ form_id = pause_reason.get("form_id")
+ if not isinstance(form_id, str):
+ continue
+
+ expiration_time = expiration_times_by_form_id.get(form_id)
+ if expiration_time is None:
+ raise ValueError(f"HumanInputForm not found for pause reason, form_id={form_id}")
+
+ pause_reason["form_token"] = form_token_by_form_id.get(form_id)
+ pause_reason["expiration_time"] = int(expiration_time.timestamp())
+
responses: list[StreamResponse] = []
for reason in event.reasons:
diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py
index 71886b39ba..505a6d4507 100644
--- a/api/core/app/apps/completion/generate_response_converter.py
+++ b/api/core/app/apps/completion/generate_response_converter.py
@@ -12,11 +12,10 @@ from core.app.entities.task_entities import (
)
-class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = CompletionAppBlockingResponse
+class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +35,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py
index 68631bb230..8b9290a217 100644
--- a/api/core/app/apps/message_generator.py
+++ b/api/core/app/apps/message_generator.py
@@ -1,5 +1,6 @@
-from collections.abc import Callable, Generator, Mapping
+from collections.abc import Callable, Generator, Iterable, Mapping
+from core.app.entities.task_entities import StreamEvent
from core.app.apps.streaming_utils import stream_topic_events
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
@@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
+ terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
+ terminal_events=terminal_events,
)
diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py
index 02b3160b7c..d12efbc298 100644
--- a/api/core/app/apps/pipeline/generate_response_converter.py
+++ b/api/core/app/apps/pipeline/generate_response_converter.py
@@ -13,11 +13,10 @@ from core.app.entities.task_entities import (
)
-class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = WorkflowAppBlockingResponse
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -26,7 +25,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 6074e81d1e..116487d15b 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -29,7 +29,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
-from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.app.entities.task_entities import (
+ WorkflowAppBlockingResponse,
+ WorkflowAppPausedBlockingResponse,
+ WorkflowAppStreamResponse,
+)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
- ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
+ ) -> (
+ WorkflowAppBlockingResponse
+ | WorkflowAppPausedBlockingResponse
+ | Generator[WorkflowAppStreamResponse, None, None]
+ ):
"""
Handle response.
:param application_generate_entity: application generate entity
diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py
index c69826cbef..8f482aefca 100644
--- a/api/core/app/apps/workflow/generate_response_converter.py
+++ b/api/core/app/apps/workflow/generate_response_converter.py
@@ -1,6 +1,7 @@
 from collections.abc import Generator
 from typing import Any, cast
+
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@@ -9,24 +11,30 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
+ WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
-class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
- _blocking_response_type = WorkflowAppBlockingResponse
+class WorkflowAppGenerateResponseConverter(
+ AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
+):
@classmethod
- def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_full_response(
+ cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
+ ) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
- return blocking_response.model_dump()
+ return dict(blocking_response.model_dump())
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
+ def convert_blocking_simple_response(
+ cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
+ ) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -58,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
- response_chunk.update(data)
+ response_chunk.update(cast(dict[str, object], data))
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@@ -87,9 +95,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
- response_chunk.update(data)
+ response_chunk.update(cast(dict[str, object], data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
- response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+ response_chunk.update(cast(dict[str, object], sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 96387133b1..1c60f5347d 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -45,12 +45,14 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
+ HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
+ WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@@ -118,7 +120,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
- def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
+ def process(
+ self,
+ ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@@ -129,19 +133,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
- def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
+ def _to_blocking_response(
+ self, generator: Generator[StreamResponse, None, None]
+ ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
+ human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
+ elif isinstance(stream_response, HumanInputRequiredResponse):
+ human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
- response = WorkflowAppBlockingResponse(
+ response = WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
- data=WorkflowAppBlockingResponse.Data(
+ data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@@ -152,6 +161,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
+ paused_nodes=stream_response.data.paused_nodes,
+ reasons=stream_response.data.reasons,
),
)
@@ -178,8 +189,41 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
+ if human_input_responses:
+ return self._build_paused_blocking_response_from_human_input(human_input_responses)
+
raise ValueError("queue listening stopped unexpectedly.")
+ def _build_paused_blocking_response_from_human_input(
+ self, human_input_responses: list[HumanInputRequiredResponse]
+ ) -> WorkflowAppPausedBlockingResponse:
+ runtime_state = self._resolve_graph_runtime_state()
+ paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
+ reasons = []
+ for response in human_input_responses:
+ reason = response.data.model_dump(mode="json")
+ reason["type"] = "human_input_required"
+ reasons.append(reason)
+
+ return WorkflowAppPausedBlockingResponse(
+ task_id=self._application_generate_entity.task_id,
+ workflow_run_id=human_input_responses[-1].workflow_run_id,
+ data=WorkflowAppPausedBlockingResponse.Data(
+ id=human_input_responses[-1].workflow_run_id,
+ workflow_id=self._workflow.id,
+ status=WorkflowExecutionStatus.PAUSED,
+ outputs={},
+ error=None,
+ elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
+ total_tokens=runtime_state.total_tokens,
+ total_steps=runtime_state.node_run_steps,
+ created_at=int(runtime_state.start_at),
+ finished_at=None,
+ paused_nodes=paused_nodes,
+ reasons=reasons,
+ ),
+ )
+
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 88faf235d1..b6cd5ed5f9 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -774,6 +774,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
+class ChatbotAppPausedBlockingResponse(AppBlockingResponse):
+ """
+ ChatbotAppPausedBlockingResponse entity
+ """
+
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+
+ id: str
+ mode: str
+ conversation_id: str
+ message_id: str
+ workflow_run_id: str
+ answer: str
+ metadata: Mapping[str, object] = Field(default_factory=dict)
+ created_at: int
+ paused_nodes: Sequence[str] = Field(default_factory=list)
+ reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
+ status: WorkflowExecutionStatus
+ elapsed_time: float
+ total_tokens: int
+ total_steps: int
+
+ data: Data
+
+
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@@ -819,6 +847,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
+class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
+ """
+ WorkflowAppPausedBlockingResponse entity
+ """
+
+ class Data(BaseModel):
+ """
+ Data entity
+ """
+
+ id: str
+ workflow_id: str
+ status: WorkflowExecutionStatus
+ outputs: Mapping[str, Any] | None = None
+ error: str | None = None
+ elapsed_time: float
+ total_tokens: int
+ total_steps: int
+ created_at: int
+ finished_at: int | None
+ paused_nodes: Sequence[str] = Field(default_factory=list)
+ reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
+
+ workflow_run_id: str
+ data: Data
+
+
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index b760696c5e..b2ccd1d8f4 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom
-from models.human_input import HumanInputForm
+from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity
@@ -60,9 +60,20 @@ class _WorkflowRunError(Exception):
pass
+def _select_recipient_token(
+ recipients: Sequence[HumanInputFormRecipient],
+ recipient_type: RecipientType,
+) -> str | None:
+ recipient = next((recipient for recipient in recipients if recipient.recipient_type == recipient_type), None)
+ if recipient is None or not recipient.access_token:
+ return None
+ return recipient.access_token
+
+
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
+ recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired:
form_content = ""
inputs = []
@@ -89,6 +100,12 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
+ # Service API pause payloads and replayed workflow events must expose the public token used by
+ # `/form/human_input/:form_token`, so prefer the standalone web-app surface and only fall back
+ # to the console token when a web-app token is unavailable.
+ form_token = _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) or _select_recipient_token(
+ recipients, RecipientType.CONSOLE
+ )
return HumanInputRequired(
form_id=form_id,
form_content=form_content,
@@ -96,6 +113,7 @@ def _build_human_input_required_reason(
actions=actions,
node_id=node_id,
node_title=node_title,
+ form_token=form_token,
resolved_default_values=resolved_default_values,
)
@@ -804,12 +822,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
+ recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
+ if form_ids:
+ recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
+ for recipient in session.scalars(recipient_stmt).all():
+ recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
- pause_reasons.append(_build_human_input_required_reason(reason, form_model))
+ pause_reasons.append(
+ _build_human_input_required_reason(
+ reason,
+ form_model,
+ recipients_by_form_id.get(reason.form_id, ()),
+ )
+ )
else:
pause_reasons.append(reason.to_entity())
return pause_reasons
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 5e8c7aa337..8ff53d143b 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -162,6 +162,7 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=True,
call_depth=0,
+ workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
@@ -183,6 +184,10 @@ class AppGenerateService:
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
+ pause_config = PauseStateLayerConfig(
+ session_factory=session_factory.get_session_maker(),
+ state_owner_user_id=workflow.created_by,
+ )
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
@@ -194,6 +199,7 @@ class AppGenerateService:
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
+ pause_state_config=pause_config,
)
),
request_id=request_id,
diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py
index 601e9261fc..11ca99361d 100644
--- a/api/services/workflow_event_snapshot_service.py
+++ b/api/services/workflow_event_snapshot_service.py
@@ -16,8 +16,10 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker
+from core.app.apps.common.pause_reason_serializer import pause_reason_to_public_dict
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
+ HumanInputRequiredResponse,
MessageReplaceStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@@ -26,6 +28,8 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
+from core.workflow.human_input_forms import load_form_tokens_by_form_id
+from models.human_input import HumanInputForm
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
@@ -61,6 +65,7 @@ def build_workflow_event_stream(
session_maker: sessionmaker[Session],
idle_timeout: float = 300,
ping_interval: float = 10.0,
+ close_on_pause: bool = True,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@@ -115,13 +120,14 @@ def build_workflow_event_stream(
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
+ session_maker=session_maker,
)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
- if _is_terminal_event(event, include_paused=True):
+ if _is_terminal_event(event, close_on_pause=close_on_pause):
return
while True:
@@ -146,7 +152,7 @@ def build_workflow_event_stream(
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
- if _is_terminal_event(event, include_paused=True):
+ if _is_terminal_event(event, close_on_pause=close_on_pause):
return
finally:
buffer_state.stop_event.set()
@@ -207,6 +213,7 @@ def _build_snapshot_events(
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
+ session_maker: sessionmaker[Session] | None = None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
@@ -241,12 +248,22 @@ def _build_snapshot_events(
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
+ for human_input_event in _build_human_input_required_events(
+ workflow_run_id=workflow_run.id,
+ task_id=task_id,
+ pause_entity=pause_entity,
+ session_maker=session_maker,
+ ):
+ _apply_message_context(human_input_event, message_context)
+ events.append(human_input_event)
+
pause_event = _build_pause_event(
workflow_run=workflow_run,
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
+ session_maker=session_maker,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
@@ -314,6 +331,78 @@ def _build_node_started_event(
return response.to_ignore_detail_dict()
+def _build_human_input_required_events(
+ *,
+ workflow_run_id: str,
+ task_id: str,
+ pause_entity: WorkflowPauseEntity,
+ session_maker: sessionmaker[Session] | None,
+) -> list[dict[str, Any]]:
+ reasons = [pause_reason_to_public_dict(reason) for reason in pause_entity.get_pause_reasons()]
+ human_input_form_ids = [
+ form_id
+ for reason in reasons
+ if reason.get("type") == "human_input_required"
+ for form_id in [reason.get("form_id")]
+ if isinstance(form_id, str)
+ ]
+
+ expiration_times_by_form_id: dict[str, int] = {}
+ display_in_ui_by_form_id: dict[str, bool] = {}
+ form_tokens_by_form_id: dict[str, str] = {}
+ if human_input_form_ids and session_maker is not None:
+ stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
+ HumanInputForm.id.in_(human_input_form_ids)
+ )
+ with session_maker() as session:
+ for form_id, expiration_time, form_definition in session.execute(stmt):
+ expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
+ try:
+ definition_payload = json.loads(form_definition) if form_definition else {}
+ except (TypeError, json.JSONDecodeError):
+ definition_payload = {}
+ display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
+ form_tokens_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
+
+ events: list[dict[str, Any]] = []
+ for reason in reasons:
+ if reason.get("type") != "human_input_required":
+ continue
+
+ form_id = reason.get("form_id")
+ node_id = reason.get("node_id")
+ node_title = reason.get("node_title")
+ form_content = reason.get("form_content")
+ if not all(isinstance(value, str) for value in (form_id, node_id, node_title, form_content)):
+ continue
+
+ expiration_time = expiration_times_by_form_id.get(form_id)
+ if expiration_time is None:
+ continue
+
+ response = HumanInputRequiredResponse(
+ task_id=task_id,
+ workflow_run_id=workflow_run_id,
+ data=HumanInputRequiredResponse.Data(
+ form_id=form_id,
+ node_id=node_id,
+ node_title=node_title,
+ form_content=form_content,
+ inputs=reason.get("inputs") or [],
+ actions=reason.get("actions") or [],
+ display_in_ui=display_in_ui_by_form_id.get(form_id, False),
+ form_token=form_tokens_by_form_id.get(form_id),
+ resolved_default_values=reason.get("resolved_default_values") or {},
+ expiration_time=expiration_time,
+ ),
+ )
+ payload = response.model_dump(mode="json")
+ payload["event"] = response.event.value
+ events.append(payload)
+
+ return events
+
+
def _build_node_finished_event(
*,
workflow_run_id: str,
@@ -356,6 +445,7 @@ def _build_pause_event(
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
+ session_maker: sessionmaker[Session] | None,
) -> dict[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
@@ -364,7 +454,24 @@ def _build_pause_event(
paused_nodes = state.get_paused_nodes()
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
- reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
+ reasons = [pause_reason_to_public_dict(reason) for reason in pause_entity.get_pause_reasons()]
+ human_input_form_ids = [
+ form_id
+ for reason in reasons
+ if reason.get("type") == "human_input_required"
+ for form_id in [reason.get("form_id")]
+ if isinstance(form_id, str)
+ ]
+ if human_input_form_ids and session_maker is not None:
+ with session_maker() as session:
+ form_tokens_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
+ for reason in reasons:
+ if reason.get("type") != "human_input_required":
+ continue
+ form_id = reason.get("form_id")
+ if isinstance(form_id, str):
+ reason["form_token"] = form_tokens_by_form_id.get(form_id)
+
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
@@ -449,12 +556,12 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
return event
-def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
+def _is_terminal_event(event: Mapping[str, Any] | str, close_on_pause: bool = True) -> bool:
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
- if include_paused:
+ if close_on_pause:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False
diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py
index 8f2f5f261e..fbdca4ae3d 100644
--- a/api/tasks/app_generate/workflow_execute_task.py
+++ b/api/tasks/app_generate/workflow_execute_task.py
@@ -399,6 +399,8 @@ def _resume_advanced_chat(
workflow_run_id: str,
workflow_run: WorkflowRun,
) -> None:
+ resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
+
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@@ -426,7 +428,7 @@ def _resume_advanced_chat(
user=user,
conversation=conversation,
message=message,
- application_generate_entity=generate_entity,
+ application_generate_entity=resumed_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
@@ -436,9 +438,8 @@ def _resume_advanced_chat(
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise
- if generate_entity.stream:
- assert isinstance(response, Generator)
- _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
+ assert isinstance(response, Generator)
+ _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
def _resume_workflow(
@@ -455,6 +456,8 @@ def _resume_workflow(
workflow_run_repo,
pause_entity,
) -> None:
+ resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
+
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@@ -480,7 +483,7 @@ def _resume_workflow(
app_model=app_model,
workflow=workflow,
user=user,
- application_generate_entity=generate_entity,
+ application_generate_entity=resumed_generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
@@ -490,9 +493,8 @@ def _resume_workflow(
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
raise
- if generate_entity.stream:
- assert isinstance(response, Generator)
- _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
+ assert isinstance(response, Generator)
+ _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
workflow_run_repo.delete_workflow_pause(pause_entity)
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
index 64c93ac07c..6dcbbe064e 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -628,12 +628,12 @@ class TestPrivateWorkflowPauseEntity:
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
- def test_builds_reason_from_form_definition(
+ def test_prefers_standalone_web_app_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
- """Build the graph pause reason from the stored form definition."""
+ """Use the public standalone web-app token for service API payloads."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
@@ -660,6 +660,40 @@ class TestBuildHumanInputRequiredReason:
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
+ delivery = HumanInputDelivery(
+ form_id=form_model.id,
+ delivery_method_type=DeliveryMethodType.WEBAPP,
+ channel_payload="{}",
+ )
+ db_session_with_containers.add(delivery)
+ db_session_with_containers.flush()
+
+ backstage_access_token = secrets.token_urlsafe(8)
+ backstage_recipient = HumanInputFormRecipient(
+ form_id=form_model.id,
+ delivery_id=delivery.id,
+ recipient_type=RecipientType.BACKSTAGE,
+ recipient_payload=BackstageRecipientPayload().model_dump_json(),
+ access_token=backstage_access_token,
+ )
+ console_access_token = secrets.token_urlsafe(8)
+ console_recipient = HumanInputFormRecipient(
+ form_id=form_model.id,
+ delivery_id=delivery.id,
+ recipient_type=RecipientType.CONSOLE,
+ recipient_payload="{}",
+ access_token=console_access_token,
+ )
+ web_app_access_token = secrets.token_urlsafe(8)
+ web_app_recipient = HumanInputFormRecipient(
+ form_id=form_model.id,
+ delivery_id=delivery.id,
+ recipient_type=RecipientType.STANDALONE_WEB_APP,
+ recipient_payload="{}",
+ access_token=web_app_access_token,
+ )
+ db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
+ db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
@@ -688,12 +722,109 @@ class TestBuildHumanInputRequiredReason:
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
+ db_session_with_containers.refresh(backstage_recipient)
+ db_session_with_containers.refresh(console_recipient)
+ db_session_with_containers.refresh(web_app_recipient)
- reason = _build_human_input_required_reason(reason_model, form_model)
+ reason = _build_human_input_required_reason(
+ reason_model,
+ form_model,
+ [backstage_recipient, console_recipient, web_app_recipient],
+ )
assert isinstance(reason, HumanInputRequired)
+ assert reason.form_token == web_app_access_token
assert reason.node_title == "Ask Name"
assert reason.form_content == "content"
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"
assert reason.resolved_default_values == {"name": "Alice"}
+
+ def test_falls_back_to_console_token_when_web_app_token_missing(
+ self,
+ db_session_with_containers: Session,
+ test_scope: _TestScope,
+ ) -> None:
+ """Use the console token only when no standalone web-app token exists."""
+
+ expiration_time = naive_utc_now()
+ form_definition = FormDefinition(
+ form_content="content",
+ inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
+ user_actions=[UserAction(id="approve", title="Approve")],
+ rendered_content="rendered",
+ expiration_time=expiration_time,
+ default_values={"name": "Alice"},
+ node_title="Ask Name",
+ display_in_ui=True,
+ )
+
+ form_model = HumanInputForm(
+ tenant_id=test_scope.tenant_id,
+ app_id=test_scope.app_id,
+ workflow_run_id=str(uuid4()),
+ node_id="node-1",
+ form_definition=form_definition.model_dump_json(),
+ rendered_content="rendered",
+ status=HumanInputFormStatus.WAITING,
+ expiration_time=expiration_time,
+ )
+ db_session_with_containers.add(form_model)
+ db_session_with_containers.flush()
+
+ delivery = HumanInputDelivery(
+ form_id=form_model.id,
+ delivery_method_type=DeliveryMethodType.WEBAPP,
+ channel_payload="{}",
+ )
+ db_session_with_containers.add(delivery)
+ db_session_with_containers.flush()
+
+ backstage_access_token = secrets.token_urlsafe(8)
+ backstage_recipient = HumanInputFormRecipient(
+ form_id=form_model.id,
+ delivery_id=delivery.id,
+ recipient_type=RecipientType.BACKSTAGE,
+ recipient_payload=BackstageRecipientPayload().model_dump_json(),
+ access_token=backstage_access_token,
+ )
+ console_access_token = secrets.token_urlsafe(8)
+ console_recipient = HumanInputFormRecipient(
+ form_id=form_model.id,
+ delivery_id=delivery.id,
+ recipient_type=RecipientType.CONSOLE,
+ recipient_payload="{}",
+ access_token=console_access_token,
+ )
+ db_session_with_containers.add_all([backstage_recipient, console_recipient])
+ db_session_with_containers.flush()
+
+ workflow_run = _create_workflow_run(
+ db_session_with_containers,
+ test_scope,
+ status=WorkflowExecutionStatus.RUNNING,
+ )
+ pause = WorkflowPause(
+ id=str(uuid4()),
+ workflow_id=test_scope.workflow_id,
+ workflow_run_id=workflow_run.id,
+ state_object_key=f"workflow-state-{uuid4()}.json",
+ )
+ db_session_with_containers.add(pause)
+ db_session_with_containers.flush()
+ test_scope.state_keys.add(pause.state_object_key)
+
+ reason_model = WorkflowPauseReason(
+ pause_id=pause.id,
+ type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
+ form_id=form_model.id,
+ node_id="node-1",
+ message="",
+ )
+ db_session_with_containers.add(reason_model)
+ db_session_with_containers.commit()
+
+ reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
+
+ assert isinstance(reason, HumanInputRequired)
+ assert reason.form_token == console_access_token
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py
new file mode 100644
index 0000000000..d732367783
--- /dev/null
+++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py
@@ -0,0 +1,701 @@
+"""Dedicated tests for HITL behavior exposed through the Service API."""
+
+from __future__ import annotations
+
+import sys
+import json
+import queue
+from collections.abc import Sequence
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from threading import Event
+from types import SimpleNamespace
+from unittest.mock import ANY, MagicMock, Mock
+
+import pytest
+
+import services.app_generate_service as ags_module
+from controllers.service_api.app.workflow_events import WorkflowEventsApi
+from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
+from core.app.apps.common import workflow_response_converter
+from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
+from core.app.entities.queue_entities import QueueWorkflowPausedEvent
+from core.app.entities.task_entities import (
+ ChatbotAppPausedBlockingResponse,
+ HumanInputRequiredResponse,
+ WorkflowAppPausedBlockingResponse,
+ WorkflowPauseStreamResponse,
+)
+from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
+from graphon.nodes.human_input.entities import FormInput, UserAction
+from graphon.nodes.human_input.enums import FormInputType
+from graphon.runtime import GraphRuntimeState, VariablePool
+from core.workflow.system_variables import build_system_variables
+from models.account import Account
+from models.enums import CreatorUserRole
+from models.model import AppMode
+from models.workflow import WorkflowRun
+from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
+from repositories.entities.workflow_pause import WorkflowPauseEntity
+from services.app_generate_service import AppGenerateService
+from services.workflow_event_snapshot_service import _build_snapshot_events
+from tests.unit_tests.controllers.service_api.conftest import _unwrap
+
+
+class _DummyRateLimit:
+ @staticmethod
+ def gen_request_key() -> str:
+ return "dummy-request-id"
+
+ def __init__(self, client_id: str, max_active_requests: int) -> None:
+ self.client_id = client_id
+ self.max_active_requests = max_active_requests
+
+ def enter(self, request_id: str | None = None) -> str:
+ return request_id or "dummy-request-id"
+
+ def exit(self, request_id: str) -> None:
+ return None
+
+ def generate(self, generator, request_id: str):
+ return generator
+
+
+def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
+ workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
+ repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
+ monkeypatch.setattr(
+ workflow_events_module.DifyAPIRepositoryFactory,
+ "create_api_workflow_run_repository",
+ lambda *_args, **_kwargs: repo,
+ )
+ monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
+ return workflow_events_module
+
+
+def _build_service_api_pause_converter() -> WorkflowResponseConverter:
+ application_generate_entity = SimpleNamespace(
+ inputs={},
+ files=[],
+ invoke_from=InvokeFrom.SERVICE_API,
+ app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
+ )
+ system_variables = build_system_variables(
+ user_id="user",
+ app_id="app-id",
+ workflow_id="workflow-id",
+ workflow_execution_id="run-id",
+ )
+ user = MagicMock(spec=Account)
+ user.id = "account-id"
+ user.name = "Tester"
+ user.email = "tester@example.com"
+ return WorkflowResponseConverter(
+ application_generate_entity=application_generate_entity,
+ user=user,
+ system_variables=system_variables,
+ )
+
+
+def _build_advanced_chat_paused_blocking_response() -> ChatbotAppPausedBlockingResponse:
+ data = ChatbotAppPausedBlockingResponse.Data(
+ id="msg-1",
+ mode="chat",
+ conversation_id="c1",
+ message_id="m1",
+ workflow_run_id="run-1",
+ answer="partial",
+ metadata={"usage": {"total_tokens": 1}},
+ created_at=1,
+ paused_nodes=["node-1"],
+ reasons=[
+ {
+ "type": PauseReasonType.HUMAN_INPUT_REQUIRED,
+ "form_id": "form-1",
+ "expiration_time": 100,
+ }
+ ],
+ status=WorkflowExecutionStatus.PAUSED,
+ elapsed_time=0.1,
+ total_tokens=0,
+ total_steps=0,
+ )
+ return ChatbotAppPausedBlockingResponse(task_id="t1", data=data)
+
+
+def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
+ return WorkflowAppPausedBlockingResponse(
+ task_id="t1",
+ workflow_run_id="r1",
+ data=WorkflowAppPausedBlockingResponse.Data(
+ id="r1",
+ workflow_id="wf-1",
+ status=WorkflowExecutionStatus.PAUSED,
+ outputs={},
+ error=None,
+ elapsed_time=0.5,
+ total_tokens=0,
+ total_steps=2,
+ created_at=1,
+ finished_at=None,
+ paused_nodes=["node-1"],
+ reasons=[{"type": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
+ ),
+ )
+
+
+@dataclass(frozen=True)
+class _FakePauseEntity(WorkflowPauseEntity):
+ pause_id: str
+ workflow_run_id: str
+ paused_at_value: datetime
+ pause_reasons: Sequence[HumanInputRequired]
+
+ @property
+ def id(self) -> str:
+ return self.pause_id
+
+ @property
+ def workflow_execution_id(self) -> str:
+ return self.workflow_run_id
+
+ def get_state(self) -> bytes:
+ raise AssertionError("state is not required for snapshot tests")
+
+ @property
+ def resumed_at(self) -> datetime | None:
+ return None
+
+ @property
+ def paused_at(self) -> datetime:
+ return self.paused_at_value
+
+ def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
+ return self.pause_reasons
+
+
+def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
+ return WorkflowRun(
+ id="run-1",
+ tenant_id="tenant-1",
+ app_id="app-1",
+ workflow_id="workflow-1",
+ type="workflow",
+ triggered_from="app-run",
+ version="v1",
+ graph=None,
+ inputs=json.dumps({"input": "value"}),
+ status=status,
+ outputs=json.dumps({}),
+ error=None,
+ elapsed_time=0.0,
+ total_tokens=0,
+ total_steps=0,
+ created_by_role=CreatorUserRole.END_USER,
+ created_by="user-1",
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
+ )
+
+
+def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
+ created_at = datetime(2024, 1, 1, tzinfo=UTC)
+ finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
+ return WorkflowNodeExecutionSnapshot(
+ execution_id="exec-1",
+ node_id="node-1",
+ node_type="human-input",
+ title="Human Input",
+ index=1,
+ status=status.value,
+ elapsed_time=0.5,
+ created_at=created_at,
+ finished_at=finished_at,
+ iteration_id=None,
+ loop_id=None,
+ )
+
+
+def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
+ app_config = WorkflowUIBasedAppConfig(
+ tenant_id="tenant-1",
+ app_id="app-1",
+ app_mode=AppMode.WORKFLOW,
+ workflow_id="workflow-1",
+ )
+ generate_entity = WorkflowAppGenerateEntity(
+ task_id=task_id,
+ app_config=app_config,
+ inputs={},
+ files=[],
+ user_id="user-1",
+ stream=True,
+ invoke_from=InvokeFrom.EXPLORE,
+ call_depth=0,
+ workflow_execution_id="run-1",
+ )
+ runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
+ runtime_state.register_paused_node("node-1")
+ runtime_state.outputs = {"result": "value"}
+ wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
+ return WorkflowResumptionContext(
+ generate_entity=wrapper,
+ serialized_graph_runtime_state=runtime_state.dumps(),
+ )
+
+
+class TestHitlServiceApi:
+ # Service API event-stream continuation
+ def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+ workflow_run = SimpleNamespace(
+ id="run-1",
+ app_id="app-1",
+ created_by_role=CreatorUserRole.END_USER,
+ created_by="end-user-1",
+ finished_at=None,
+ )
+ workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+ msg_generator = Mock()
+ msg_generator.retrieve_events.return_value = ["raw-event"]
+ workflow_generator = Mock()
+ workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
+ monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+ monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+
+ api = WorkflowEventsApi()
+ handler = _unwrap(api.get)
+ app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+ end_user = SimpleNamespace(id="end-user-1")
+
+ with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
+ response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+ assert response.get_data(as_text=True) == "data: streamed\n\n"
+ msg_generator.retrieve_events.assert_called_once_with(
+ AppMode.WORKFLOW,
+ "run-1",
+ terminal_events=["workflow_finished"],
+ )
+ workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
+
+ def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
+ self, app, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ workflow_run = SimpleNamespace(
+ id="run-1",
+ app_id="app-1",
+ created_by_role=CreatorUserRole.END_USER,
+ created_by="end-user-1",
+ finished_at=None,
+ )
+ workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+ msg_generator = Mock()
+ workflow_generator = Mock()
+ workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
+ snapshot_builder = Mock(return_value=["snapshot-events"])
+ monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+ monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+ monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
+
+ api = WorkflowEventsApi()
+ handler = _unwrap(api.get)
+ app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+ end_user = SimpleNamespace(id="end-user-1")
+
+ with app.test_request_context(
+ "/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
+ method="GET",
+ ):
+ response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+ assert response.get_data(as_text=True) == "data: snapshot\n\n"
+ msg_generator.retrieve_events.assert_not_called()
+ snapshot_builder.assert_called_once_with(
+ app_mode=AppMode.WORKFLOW,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=ANY,
+ close_on_pause=False,
+ )
+ workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
+
+ def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
+ monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
+ monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
+
+ workflow = MagicMock()
+ workflow.created_by = "owner-id"
+ monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
+ monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
+
+ generator_instance = MagicMock()
+ generator_instance.generate.return_value = {"result": "advanced-blocking"}
+ generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
+ monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
+
+ app_model = MagicMock()
+ app_model.mode = AppMode.ADVANCED_CHAT
+ app_model.id = "app-id"
+ app_model.tenant_id = "tenant-id"
+ app_model.max_active_requests = 0
+ app_model.is_agent = False
+
+ user = MagicMock()
+ user.id = "user-id"
+
+ result = AppGenerateService.generate(
+ app_model=app_model,
+ user=user,
+ args={"workflow_id": None, "query": "hi", "inputs": {}},
+ invoke_from=InvokeFrom.SERVICE_API,
+ streaming=False,
+ )
+
+ assert result == {"result": "advanced-blocking"}
+ call_kwargs = generator_instance.generate.call_args.kwargs
+ assert call_kwargs["streaming"] is False
+ assert call_kwargs["pause_state_config"] is not None
+ assert call_kwargs["pause_state_config"].session_factory == "session-maker"
+ assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
+
+ # Blocking payload contract
+ def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
+ from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
+
+ response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
+ _build_advanced_chat_paused_blocking_response()
+ )
+
+ assert response["event"] == "workflow_paused"
+ assert response["workflow_run_id"] == "run-1"
+ assert response["answer"] == "partial"
+ assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
+ assert response["data"]["reasons"][0]["expiration_time"] == 100
+ assert "human_input_forms" not in response["data"]
+
+ def test_workflow_blocking_pause_payload_contract(self) -> None:
+ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
+
+ response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
+ _build_workflow_paused_blocking_response()
+ )
+
+ assert response["workflow_run_id"] == "r1"
+ assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
+ assert response["data"]["paused_nodes"] == ["node-1"]
+ assert response["data"]["reasons"] == [
+ {"type": "human_input_required", "form_id": "form-1", "expiration_time": 100}
+ ]
+ assert "human_input_forms" not in response["data"]
+
+ def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
+ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
+ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
+ from core.app.app_config.entities import AppAdditionalFeatures
+ from models.enums import MessageStatus
+ from models.model import EndUser
+
+ app_config = WorkflowUIBasedAppConfig(
+ tenant_id="tenant",
+ app_id="app",
+ app_mode=AppMode.ADVANCED_CHAT,
+ additional_features=AppAdditionalFeatures(),
+ variables=[],
+ workflow_id="workflow-id",
+ )
+ application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
+ task_id="task",
+ app_config=app_config,
+ inputs={},
+ query="hello",
+ files=[],
+ user_id="user",
+ stream=False,
+ invoke_from=InvokeFrom.WEB_APP,
+ extras={},
+ trace_manager=None,
+ workflow_run_id="run-id",
+ )
+ pipeline = AdvancedChatAppGenerateTaskPipeline(
+ application_generate_entity=application_generate_entity,
+ workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
+ queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
+ conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
+ message=SimpleNamespace(
+ id="message-id",
+ query="hello",
+ created_at=datetime.utcnow(),
+ status=MessageStatus.NORMAL,
+ answer="",
+ ),
+ user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
+ stream=False,
+ dialogue_count=1,
+ draft_var_saver_factory=lambda **kwargs: None,
+ )
+ pipeline._task_state.answer = "partial answer"
+ pipeline._workflow_run_id = "run-id"
+
+ def _gen():
+ yield HumanInputRequiredResponse(
+ task_id="task",
+ workflow_run_id="run-id",
+ data=HumanInputRequiredResponse.Data(
+ form_id="form-1",
+ node_id="node-1",
+ node_title="Approval",
+ form_content="Need approval",
+ inputs=[],
+ actions=[UserAction(id="approve", title="Approve")],
+ display_in_ui=True,
+ form_token="token-1",
+ resolved_default_values={},
+ expiration_time=123,
+ ),
+ )
+ yield WorkflowPauseStreamResponse(
+ task_id="task",
+ workflow_run_id="run-id",
+ data=WorkflowPauseStreamResponse.Data(
+ workflow_run_id="run-id",
+ paused_nodes=["node-1"],
+ outputs={},
+ reasons=[
+ {
+ "type": PauseReasonType.HUMAN_INPUT_REQUIRED,
+ "form_id": "form-1",
+ "node_id": "node-1",
+ "expiration_time": 123,
+ },
+ ],
+ status="paused",
+ created_at=1,
+ elapsed_time=0.1,
+ total_tokens=0,
+ total_steps=0,
+ ),
+ )
+
+ response = pipeline._to_blocking_response(_gen())
+
+ assert isinstance(response, ChatbotAppPausedBlockingResponse)
+ assert response.data.answer == "partial answer"
+ assert response.data.workflow_run_id == "run-id"
+ assert response.data.reasons[0]["form_id"] == "form-1"
+ assert response.data.reasons[0]["expiration_time"] == 123
+
+ def test_workflow_blocking_pipeline_pause_payload_contract(self) -> None:
+ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
+
+ app_config = WorkflowUIBasedAppConfig(
+ tenant_id="tenant",
+ app_id="app",
+ app_mode=AppMode.WORKFLOW,
+ additional_features=AppAdditionalFeatures(),
+ variables=[],
+ workflow_id="workflow-id",
+ )
+ application_generate_entity = WorkflowAppGenerateEntity.model_construct(
+ task_id="task",
+ app_config=app_config,
+ inputs={},
+ files=[],
+ user_id="user",
+ stream=False,
+ invoke_from=InvokeFrom.WEB_APP,
+ trace_manager=None,
+ workflow_execution_id="run-id",
+ extras={},
+ call_depth=0,
+ )
+ pipeline = WorkflowAppGenerateTaskPipeline(
+ application_generate_entity=application_generate_entity,
+ workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
+ queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
+ user=SimpleNamespace(id="user", session_id="session"),
+ stream=False,
+ draft_var_saver_factory=lambda **kwargs: None,
+ )
+
+ def _gen():
+ yield HumanInputRequiredResponse(
+ task_id="task",
+ workflow_run_id="run",
+ data=HumanInputRequiredResponse.Data(
+ form_id="form-1",
+ node_id="node-1",
+ node_title="Human Input",
+ form_content="content",
+ expiration_time=1,
+ ),
+ )
+ yield WorkflowPauseStreamResponse(
+ task_id="task",
+ workflow_run_id="run",
+ data=WorkflowPauseStreamResponse.Data(
+ workflow_run_id="run",
+ status=WorkflowExecutionStatus.PAUSED,
+ outputs={},
+ paused_nodes=["node-1"],
+ reasons=[{"type": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
+ created_at=1,
+ elapsed_time=0.1,
+ total_tokens=0,
+ total_steps=0,
+ ),
+ )
+
+ response = pipeline._to_blocking_response(_gen())
+
+ assert isinstance(response, WorkflowAppPausedBlockingResponse)
+ assert response.data.status == WorkflowExecutionStatus.PAUSED
+ assert response.data.paused_nodes == ["node-1"]
+ assert response.data.reasons == [
+ {"type": "human_input_required", "form_id": "form-1", "expiration_time": 1}
+ ]
+
+ def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
+ converter = _build_service_api_pause_converter()
+ converter.workflow_start_to_stream_response(
+ task_id="task",
+ workflow_run_id="run-id",
+ workflow_id="workflow-id",
+ reason=WorkflowStartReason.INITIAL,
+ )
+
+ expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
+
+ class _FakeSession:
+ def execute(self, _stmt):
+ return [("form-1", expiration_time, '{"display_in_ui": true}')]
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc, tb):
+ return False
+
+ monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
+ monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
+ monkeypatch.setattr(
+ workflow_response_converter,
+ "load_form_tokens_by_form_id",
+ lambda form_ids, session=None: {"form-1": "token"},
+ )
+
+ reason = HumanInputRequired(
+ form_id="form-1",
+ form_content="Rendered",
+ inputs=[
+ FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
+ ],
+ actions=[UserAction(id="approve", title="Approve")],
+ display_in_ui=True,
+ node_id="node-id",
+ node_title="Human Step",
+ form_token="token",
+ )
+ queue_event = QueueWorkflowPausedEvent(
+ reasons=[reason],
+ outputs={"answer": "value"},
+ paused_nodes=["node-id"],
+ )
+
+ runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
+ responses = converter.workflow_pause_to_stream_response(
+ event=queue_event,
+ task_id="task",
+ graph_runtime_state=runtime_state,
+ )
+
+ assert isinstance(responses[-1], WorkflowPauseStreamResponse)
+ pause_resp = responses[-1]
+ assert pause_resp.workflow_run_id == "run-id"
+ assert pause_resp.data.paused_nodes == ["node-id"]
+ assert pause_resp.data.outputs == {}
+ assert pause_resp.data.reasons[0]["type"] == "human_input_required"
+ assert pause_resp.data.reasons[0]["form_id"] == "form-1"
+ assert pause_resp.data.reasons[0]["form_token"] == "token"
+ assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
+
+ assert isinstance(responses[0], HumanInputRequiredResponse)
+ hi_resp = responses[0]
+ assert hi_resp.data.form_id == "form-1"
+ assert hi_resp.data.node_id == "node-id"
+ assert hi_resp.data.node_title == "Human Step"
+ assert hi_resp.data.inputs[0].output_variable_name == "field"
+ assert hi_resp.data.actions[0].id == "approve"
+ assert hi_resp.data.display_in_ui is True
+ assert hi_resp.data.form_token == "token"
+ assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
+
+ # Snapshot payload contract
+ def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
+ workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
+ snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
+ resumption_context = _build_resumption_context("task-ctx")
+ monkeypatch.setattr("services.workflow_event_snapshot_service.load_form_tokens_by_form_id", lambda form_ids, session=None: {"form-1": "wtok"})
+
+ class _SessionContext:
+ def __init__(self, session):
+ self._session = session
+
+ def __enter__(self):
+ return self._session
+
+ def __exit__(self, exc_type, exc, tb):
+ return False
+
+ session_maker = lambda: _SessionContext(
+ SimpleNamespace(
+ execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
+ )
+ )
+ pause_entity = _FakePauseEntity(
+ pause_id="pause-1",
+ workflow_run_id="run-1",
+ paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
+ pause_reasons=[
+ HumanInputRequired(
+ form_id="form-1",
+ form_content="content",
+ node_id="node-1",
+ node_title="Human Input",
+ form_token="wtok",
+ )
+ ],
+ )
+
+ events = _build_snapshot_events(
+ workflow_run=workflow_run,
+ node_snapshots=[snapshot],
+ task_id="task-ctx",
+ message_context=None,
+ pause_entity=pause_entity,
+ resumption_context=resumption_context,
+ session_maker=session_maker,
+ )
+
+ assert [event["event"] for event in events] == [
+ "workflow_started",
+ "node_started",
+ "node_finished",
+ "human_input_required",
+ "workflow_paused",
+ ]
+ assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
+ assert events[3]["data"]["form_token"] == "wtok"
+ pause_data = events[-1]["data"]
+ assert pause_data["paused_nodes"] == ["node-1"]
+ assert pause_data["outputs"] == {"result": "value"}
+ assert pause_data["reasons"][0]["type"] == "human_input_required"
+ assert pause_data["reasons"][0]["form_token"] == "wtok"
+ assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
+ assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
+ assert pause_data["elapsed_time"] == workflow_run.elapsed_time
+ assert pause_data["total_tokens"] == workflow_run.total_tokens
+ assert pause_data["total_steps"] == workflow_run.total_steps
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py
index 6ec33e4884..9bb544bca5 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py
@@ -6,7 +6,7 @@ import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
-from unittest.mock import Mock
+from unittest.mock import ANY, Mock
import pytest
from werkzeug.exceptions import NotFound
@@ -128,7 +128,11 @@ class TestWorkflowEventsApi:
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
- msg_generator.retrieve_events.assert_called_once_with(AppMode.WORKFLOW, "run-1")
+ msg_generator.retrieve_events.assert_called_once_with(
+ AppMode.WORKFLOW,
+ "run-1",
+ terminal_events=None,
+ )
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
index e9fdeefee4..c5c3a00a82 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
@@ -1,10 +1,10 @@
from collections.abc import Generator
-from graphon.enums import WorkflowNodeExecutionStatus
-
+import pytest
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
+ ChatbotAppPausedBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
@@ -12,6 +12,8 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
)
+from graphon.entities.pause_reason import PauseReasonType
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
class TestAdvancedChatGenerateResponseConverter:
@@ -29,6 +31,37 @@ class TestAdvancedChatGenerateResponseConverter:
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "usage" not in response["metadata"]
+ def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch):
+ data = ChatbotAppPausedBlockingResponse.Data(
+ id="msg-1",
+ mode="chat",
+ conversation_id="c1",
+ message_id="m1",
+ workflow_run_id="run-1",
+ answer="partial",
+ metadata={"usage": {"total_tokens": 1}},
+ created_at=1,
+ paused_nodes=["node-1"],
+ reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}],
+ status=WorkflowExecutionStatus.PAUSED,
+ elapsed_time=0.1,
+ total_tokens=0,
+ total_steps=0,
+ )
+ original_model_dump = type(data).model_dump
+
+ def _model_dump_with_future_field(self, *args, **kwargs):
+ payload = original_model_dump(self, *args, **kwargs)
+ payload["future_field"] = "future-value"
+ return payload
+
+ monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field)
+ blocking = ChatbotAppPausedBlockingResponse(task_id="t1", data=data)
+
+ response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
+
+ assert response["data"]["future_field"] == "future-value"
+
def test_stream_simple_response_includes_node_events(self):
node_start = NodeStartStreamResponse(
task_id="t1",
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
index 82b2e51019..a03e7b3c2b 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
@@ -41,15 +41,20 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
+ ChatbotAppPausedBlockingResponse,
+ HumanInputRequiredResponse,
AnnotationReply,
AnnotationReplyAccount,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
+ WorkflowPauseStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from core.workflow.system_variables import build_system_variables
from libs.datetime_utils import naive_utc_now
+from graphon.entities.pause_reason import PauseReasonType
+from graphon.nodes.human_input.entities import UserAction
from models.enums import MessageStatus
from models.model import AppMode, EndUser
from tests.workflow_test_utils import build_test_variable_pool
@@ -123,6 +128,57 @@ class TestAdvancedChatGenerateTaskPipeline:
assert response.data.answer == "done"
assert response.data.metadata == {"k": "v"}
+ def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
+ pipeline = _make_pipeline()
+ pipeline._task_state.answer = "partial answer"
+ pipeline._workflow_run_id = "run-id"
+ pipeline._graph_runtime_state = GraphRuntimeState(
+ variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
+ start_at=0.0,
+ total_tokens=7,
+ node_run_steps=3,
+ )
+
+ def _gen():
+ yield HumanInputRequiredResponse(
+ task_id="task",
+ workflow_run_id="run-id",
+ data=HumanInputRequiredResponse.Data(
+ form_id="form-1",
+ node_id="node-1",
+ node_title="Approval",
+ form_content="Need approval",
+ inputs=[],
+ actions=[UserAction(id="approve", title="Approve")],
+ display_in_ui=True,
+ form_token="token-1",
+ resolved_default_values={},
+ expiration_time=123,
+ ),
+ )
+
+ response = pipeline._to_blocking_response(_gen())
+
+ assert isinstance(response, ChatbotAppPausedBlockingResponse)
+ assert response.data.workflow_run_id == "run-id"
+ assert response.data.status == "paused"
+ assert response.data.paused_nodes == ["node-1"]
+ assert response.data.reasons == [
+ {
+ "type": PauseReasonType.HUMAN_INPUT_REQUIRED,
+ "form_id": "form-1",
+ "node_id": "node-1",
+ "node_title": "Approval",
+ "form_content": "Need approval",
+ "inputs": [],
+ "actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
+ "display_in_ui": True,
+ "form_token": "token-1",
+ "resolved_default_values": {},
+ "expiration_time": 123,
+ }
+ ]
+
def test_handle_text_chunk_event_updates_state(self):
pipeline = _make_pipeline()
pipeline._message_cycle_manager = SimpleNamespace(
diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py
new file mode 100644
index 0000000000..560652f8cb
--- /dev/null
+++ b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+from collections.abc import Generator
+
+from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.task_entities import (
+ AppStreamResponse,
+ PingStreamResponse,
+ WorkflowAppBlockingResponse,
+ WorkflowAppStreamResponse,
+)
+from graphon.enums import WorkflowExecutionStatus
+
+
+class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
+ blocking_full_calls: list[WorkflowAppBlockingResponse] = []
+ blocking_simple_calls: list[WorkflowAppBlockingResponse] = []
+ stream_full_calls: list[Generator[AppStreamResponse, None, None]] = []
+ stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = []
+
+ @classmethod
+ def reset(cls) -> None:
+ cls.blocking_full_calls = []
+ cls.blocking_simple_calls = []
+ cls.stream_full_calls = []
+ cls.stream_simple_calls = []
+
+ @classmethod
+ def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
+ cls.blocking_full_calls.append(blocking_response)
+ return {"kind": "blocking-full", "task_id": blocking_response.task_id}
+
+ @classmethod
+ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
+ cls.blocking_simple_calls.append(blocking_response)
+ return {"kind": "blocking-simple", "task_id": blocking_response.task_id}
+
+ @classmethod
+ def convert_stream_full_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[dict | str, None, None]:
+ cls.stream_full_calls.append(stream_response)
+ yield {"kind": "stream-full"}
+
+ @classmethod
+ def convert_stream_simple_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[dict | str, None, None]:
+ cls.stream_simple_calls.append(stream_response)
+ yield {"kind": "stream-simple"}
+
+
+def _build_blocking_response() -> WorkflowAppBlockingResponse:
+ return WorkflowAppBlockingResponse(
+ task_id="task-1",
+ workflow_run_id="run-1",
+ data=WorkflowAppBlockingResponse.Data(
+ id="run-1",
+ workflow_id="workflow-1",
+ status=WorkflowExecutionStatus.SUCCEEDED,
+ outputs={"ok": True},
+ error=None,
+ elapsed_time=0.1,
+ total_tokens=0,
+ total_steps=1,
+ created_at=1,
+ finished_at=2,
+ ),
+ )
+
+
+def _build_stream_response() -> Generator[AppStreamResponse, None, None]:
+ yield WorkflowAppStreamResponse(
+ workflow_run_id="run-1",
+ stream_response=PingStreamResponse(task_id="task-1"),
+ )
+
+
+def test_convert_routes_blocking_response_by_invoke_from() -> None:
+ _DummyConverter.reset()
+ blocking_response = _build_blocking_response()
+
+ full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API)
+ simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP)
+
+ assert full_result == {"kind": "blocking-full", "task_id": "task-1"}
+ assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"}
+ assert _DummyConverter.blocking_full_calls == [blocking_response]
+ assert _DummyConverter.blocking_simple_calls == [blocking_response]
+
+
+def test_convert_routes_stream_response_by_invoke_from() -> None:
+ _DummyConverter.reset()
+
+ full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API))
+ simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP))
+
+ assert full_result == [{"kind": "stream-full"}]
+ assert simple_result == [{"kind": "stream-simple"}]
+ assert len(_DummyConverter.stream_full_calls) == 1
+ assert len(_DummyConverter.stream_simple_calls) == 1
diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py
index 25377e633e..575bcca4bc 100644
--- a/api/tests/unit_tests/core/app/apps/test_message_generator.py
+++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py
@@ -1,5 +1,6 @@
from unittest.mock import Mock, patch
+from core.app.entities.task_entities import StreamEvent
from core.app.apps.message_generator import MessageGenerator
from models.model import AppMode
@@ -23,7 +24,21 @@ class TestMessageGenerator:
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
) as mock_stream,
):
- events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
+ events = list(
+ MessageGenerator.retrieve_events(
+ AppMode.WORKFLOW,
+ "run-1",
+ idle_timeout=1,
+ ping_interval=2,
+ terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
+ )
+ )
assert events == [{"event": "ping"}]
- mock_stream.assert_called_once()
+ mock_stream.assert_called_once_with(
+ topic="topic",
+ idle_timeout=1,
+ ping_interval=2,
+ on_subscribe=None,
+ terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
+ )
diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py
index a7714c56ce..4c613f120d 100644
--- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py
+++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py
@@ -106,3 +106,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
assert next(generator) == StreamEvent.PING.value
# next receive yields None -> ping interval triggers
assert next(generator) == StreamEvent.PING.value
+
+
+def test_stream_topic_events_can_continue_past_pause():
+ topic = FakeTopic()
+ topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode())
+ topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())
+
+ generator = stream_topic_events(
+ topic=topic,
+ idle_timeout=1.0,
+ terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
+ )
+
+ assert next(generator) == StreamEvent.PING.value
+ assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value
+ assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value
+ with pytest.raises(StopIteration):
+ next(generator)
diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
index 8a717e1dcc..5259052c8f 100644
--- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
+++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
@@ -90,7 +90,6 @@ def test_graph_run_paused_event_emits_queue_pause_event():
assert queue_event.outputs == {"foo": "bar"}
assert queue_event.paused_nodes == ["node-pause-1"]
-
def _build_converter():
application_generate_entity = SimpleNamespace(
inputs={},
diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py
index b768e813bd..e0d2ae51aa 100644
--- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py
@@ -9,6 +9,7 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
+ WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
index d91bb85aee..2beb327b66 100644
--- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
@@ -38,10 +38,12 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
+ HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
WorkflowFinishStreamResponse,
+ WorkflowAppPausedBlockingResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
@@ -91,27 +93,49 @@ def _make_pipeline():
class TestWorkflowGenerateTaskPipeline:
- def test_to_blocking_response_handles_pause(self):
+ def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
+ pipeline._graph_runtime_state = GraphRuntimeState(
+ variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
+ start_at=0.0,
+ total_tokens=5,
+ node_run_steps=2,
+ )
def _gen():
- yield WorkflowPauseStreamResponse(
+ yield HumanInputRequiredResponse(
task_id="task",
- workflow_run_id="run",
- data=WorkflowPauseStreamResponse.Data(
- workflow_run_id="run",
- status=WorkflowExecutionStatus.PAUSED,
- outputs={},
- created_at=1,
- elapsed_time=0.1,
- total_tokens=0,
- total_steps=0,
+ workflow_run_id="run-id",
+ data=HumanInputRequiredResponse.Data(
+ form_id="form-1",
+ node_id="node-1",
+ node_title="Human Input",
+ form_content="content",
+ expiration_time=1,
),
)
response = pipeline._to_blocking_response(_gen())
+ assert isinstance(response, WorkflowAppPausedBlockingResponse)
+ assert response.workflow_run_id == "run-id"
assert response.data.status == WorkflowExecutionStatus.PAUSED
+ assert response.data.paused_nodes == ["node-1"]
+ assert response.data.reasons == [
+ {
+ "type": "human_input_required",
+ "form_id": "form-1",
+ "node_id": "node-1",
+ "node_title": "Human Input",
+ "form_content": "content",
+ "inputs": [],
+ "actions": [],
+ "display_in_ui": False,
+ "form_token": None,
+ "resolved_default_values": {},
+ "expiration_time": 1,
+ }
+ ]
def test_to_blocking_response_handles_finish(self):
pipeline = _make_pipeline()
diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
new file mode 100644
index 0000000000..c4cbfd228b
--- /dev/null
+++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from types import SimpleNamespace
+
+from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction
+from graphon.nodes.human_input.enums import FormInputType
+from models.human_input import RecipientType
+from repositories.sqlalchemy_api_workflow_run_repository import _build_human_input_required_reason
+
+
+def _build_form_model() -> SimpleNamespace:
+ expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
+ definition = FormDefinition(
+ form_content="content",
+ inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
+ user_actions=[UserAction(id="approve", title="Approve")],
+ rendered_content="rendered",
+ expiration_time=expiration_time,
+ default_values={"name": "Alice"},
+ node_title="Ask Name",
+ display_in_ui=True,
+ )
+ return SimpleNamespace(
+ id="form-1",
+ node_id="node-1",
+ form_definition=definition.model_dump_json(),
+ expiration_time=expiration_time,
+ )
+
+
+def _build_reason_model() -> SimpleNamespace:
+ return SimpleNamespace(form_id="form-1", node_id="node-1")
+
+
+def test_build_human_input_required_reason_prefers_standalone_web_app_token() -> None:
+ reason = _build_human_input_required_reason(
+ _build_reason_model(),
+ _build_form_model(),
+ [
+ SimpleNamespace(recipient_type=RecipientType.BACKSTAGE, access_token="btok"),
+ SimpleNamespace(recipient_type=RecipientType.CONSOLE, access_token="ctok"),
+ SimpleNamespace(recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"),
+ ],
+ )
+
+ assert reason.node_title == "Ask Name"
+ assert reason.resolved_default_values == {"name": "Alice"}
+
+
+def test_build_human_input_required_reason_falls_back_to_console_token() -> None:
+ reason = _build_human_input_required_reason(
+ _build_reason_model(),
+ _build_form_model(),
+ [
+ SimpleNamespace(recipient_type=RecipientType.BACKSTAGE, access_token="btok"),
+ SimpleNamespace(recipient_type=RecipientType.CONSOLE, access_token="ctok"),
+ ],
+ )
+
+ assert reason.node_id == "node-1"
+ assert reason.actions[0].id == "approve"
diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py
index c2b430c551..119a7adc45 100644
--- a/api/tests/unit_tests/services/test_app_generate_service.py
+++ b/api/tests/unit_tests/services/test_app_generate_service.py
@@ -327,7 +327,8 @@ class TestGenerate:
streaming=False,
)
assert result == {"result": "advanced-blocking"}
- assert gen_spy.call_args.kwargs.get("streaming") is False
+ call_kwargs = gen_spy.call_args.kwargs
+ assert call_kwargs.get("streaming") is False
retrieve_spy.assert_not_called()
# -- ADVANCED_CHAT streaming --------------------------------------------
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py
index 4146fd312b..a2fa5a4575 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py
@@ -1,28 +1,37 @@
import json
import queue
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
+from itertools import cycle
from threading import Event
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import MagicMock
import pytest
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
+from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
+from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import (
BufferState,
MessageContext,
_build_snapshot_events,
+ _is_terminal_event,
_resolve_task_id,
+ build_workflow_event_stream,
)
@@ -125,49 +133,6 @@ def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
)
-def test_build_snapshot_events_includes_pause_event() -> None:
- workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
- snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
- resumption_context = _build_resumption_context("task-ctx")
- pause_entity = _FakePauseEntity(
- pause_id="pause-1",
- workflow_run_id="run-1",
- paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
- pause_reasons=[
- HumanInputRequired(
- form_id="form-1",
- form_content="content",
- node_id="node-1",
- node_title="Human Input",
- )
- ],
- )
-
- events = _build_snapshot_events(
- workflow_run=workflow_run,
- node_snapshots=[snapshot],
- task_id="task-ctx",
- message_context=None,
- pause_entity=pause_entity,
- resumption_context=resumption_context,
- )
-
- assert [event["event"] for event in events] == [
- "workflow_started",
- "node_started",
- "node_finished",
- "workflow_paused",
- ]
- assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
- pause_data = events[-1]["data"]
- assert pause_data["paused_nodes"] == ["node-1"]
- assert pause_data["outputs"] == {"result": "value"}
- assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
- assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
- assert pause_data["elapsed_time"] == workflow_run.elapsed_time
- assert pause_data["total_tokens"] == workflow_run.total_tokens
- assert pause_data["total_steps"] == workflow_run.total_steps
-
def test_build_snapshot_events_applies_message_context() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
@@ -222,3 +187,647 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected
+
+
+def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
+ return WorkflowRun(
+ id="run-1",
+ tenant_id="tenant-1",
+ app_id="app-1",
+ workflow_id="workflow-1",
+ type="workflow",
+ triggered_from="app-run",
+ version="v1",
+ graph=None,
+ inputs=json.dumps({"query": "hello"}),
+ status=status,
+ outputs=json.dumps({}),
+ error=None,
+ elapsed_time=1.2,
+ total_tokens=5,
+ total_steps=2,
+ created_by_role=CreatorUserRole.END_USER,
+ created_by="user-1",
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
+ )
+
+
+def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
+ app_config = WorkflowUIBasedAppConfig(
+ tenant_id="tenant-1",
+ app_id="app-1",
+ app_mode=AppMode.WORKFLOW,
+ workflow_id="workflow-1",
+ )
+ generate_entity = WorkflowAppGenerateEntity(
+ task_id=task_id,
+ app_config=app_config,
+ inputs={},
+ files=[],
+ user_id="user-1",
+ stream=True,
+ invoke_from=InvokeFrom.EXPLORE,
+ call_depth=0,
+ workflow_execution_id="run-1",
+ )
+ runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
+ runtime_state.outputs = {"answer": "ok"}
+ wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
+ return WorkflowResumptionContext(
+ generate_entity=wrapper,
+ serialized_graph_runtime_state=runtime_state.dumps(),
+ )
+
+
+class _SessionContext:
+ def __init__(self, session: Any) -> None:
+ self._session = session
+
+ def __enter__(self) -> Any:
+ return self._session
+
+ def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
+ return False
+
+
+class _SessionMaker:
+ def __init__(self, session: Any) -> None:
+ self._session = session
+
+ def __call__(self) -> _SessionContext:
+ return _SessionContext(self._session)
+
+
+class _SubscriptionContext:
+ def __init__(self, subscription: Any) -> None:
+ self._subscription = subscription
+
+ def __enter__(self) -> Any:
+ return self._subscription
+
+ def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
+ return False
+
+
+class _Topic:
+ def __init__(self, subscription: Any) -> None:
+ self._subscription = subscription
+
+ def subscribe(self) -> _SubscriptionContext:
+ return _SubscriptionContext(self._subscription)
+
+
+class _StaticSubscription:
+ def receive(self, timeout: int = 1) -> None:
+ return None
+
+
+@dataclass(frozen=True)
+class _PauseEntity(WorkflowPauseEntity):
+ state: bytes
+
+ @property
+ def id(self) -> str:
+ return "pause-1"
+
+ @property
+ def workflow_execution_id(self) -> str:
+ return "run-1"
+
+ @property
+ def resumed_at(self) -> datetime | None:
+ return None
+
+ @property
+ def paused_at(self) -> datetime:
+ return datetime(2024, 1, 1, tzinfo=UTC)
+
+ def get_state(self) -> bytes:
+ return self.state
+
+ def get_pause_reasons(self) -> list[Any]:
+ return []
+
+
+def test_get_message_context_should_return_none_when_no_message() -> None:
+ # Arrange
+ session = SimpleNamespace(scalar=MagicMock(return_value=None))
+ session_maker = _SessionMaker(session)
+
+ # Act
+ result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
+
+ # Assert
+ assert result is None
+
+
+def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
+ # Arrange
+ message = SimpleNamespace(
+ id="msg-1",
+ conversation_id="conv-1",
+ created_at=None,
+ answer="answer",
+ )
+ session = SimpleNamespace(scalar=MagicMock(return_value=message))
+ session_maker = _SessionMaker(session)
+
+ # Act
+ result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
+
+ # Assert
+ assert result is not None
+ assert result.created_at == 0
+ assert result.message_id == "msg-1"
+ assert result.conversation_id == "conv-1"
+ assert result.answer == "answer"
+
+
+def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
+ # Arrange
+
+ # Act
+ result = service_module._load_resumption_context(None)
+
+ # Assert
+ assert result is None
+
+
+def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
+ # Arrange
+ pause_entity = _PauseEntity(state=b"not-a-valid-state")
+
+ # Act
+ result = service_module._load_resumption_context(pause_entity)
+
+ # Assert
+ assert result is None
+
+
+def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
+ # Arrange
+ context = _build_resumption_context_additional(task_id="task-ctx")
+ pause_entity = _PauseEntity(state=context.dumps().encode())
+
+ # Act
+ result = service_module._load_resumption_context(pause_entity)
+
+ # Assert
+ assert result is not None
+ assert result.get_generate_entity().task_id == "task-ctx"
+
+
+def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
+ # Arrange
+
+ # Act
+ result = service_module._resolve_task_id(
+ resumption_context=None,
+ buffer_state=None,
+ workflow_run_id="run-1",
+ )
+
+ # Assert
+ assert result == "run-1"
+
+
+@pytest.mark.parametrize(
+ ("payload", "expected"),
+ [
+ (b'{"event":"node_started"}', {"event": "node_started"}),
+ (b"invalid-json", None),
+ (b"[]", None),
+ ],
+)
+def test_parse_event_message_should_parse_only_json_object(
+ payload: bytes,
+ expected: dict[str, Any] | None,
+) -> None:
+ # Arrange
+
+ # Act
+ result = service_module._parse_event_message(payload)
+
+ # Assert
+ assert result == expected
+
+
+def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
+ # Arrange
+ finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
+ paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
+
+ # Act
+ is_finished = service_module._is_terminal_event(finished_event, close_on_pause=False)
+ paused_without_flag = service_module._is_terminal_event(paused_event, close_on_pause=False)
+ paused_with_flag = service_module._is_terminal_event(paused_event, close_on_pause=True)
+
+ # Assert
+ assert is_finished is True
+ assert paused_without_flag is False
+ assert paused_with_flag is True
+ assert service_module._is_terminal_event(StreamEvent.PING.value, close_on_pause=True) is False
+
+
+def test_apply_message_context_should_update_payload_when_context_exists() -> None:
+ # Arrange
+ payload: dict[str, Any] = {"event": "workflow_started"}
+ context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
+
+ # Act
+ service_module._apply_message_context(payload, context)
+
+ # Assert
+ assert payload["conversation_id"] == "conv-1"
+ assert payload["message_id"] == "msg-1"
+ assert payload["created_at"] == 1700000000
+
+
+def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
+ # Arrange
+ class Subscription:
+ def __init__(self) -> None:
+ self._calls = 0
+
+ def receive(self, timeout: int = 1) -> bytes | None:
+ self._calls += 1
+ if self._calls == 1:
+ return b'{"event":"node_started","task_id":"task-1"}'
+ return None
+
+ subscription = Subscription()
+
+ # Act
+ buffer_state = service_module._start_buffering(subscription)
+ ready = buffer_state.task_id_ready.wait(timeout=1)
+ event = buffer_state.queue.get(timeout=1)
+ buffer_state.stop_event.set()
+ finished = buffer_state.done_event.wait(timeout=1)
+
+ # Assert
+ assert ready is True
+ assert finished is True
+ assert buffer_state.task_id_hint == "task-1"
+ assert event["event"] == "node_started"
+
+
+def test_start_buffering_should_drop_old_event_when_queue_is_full(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ # Arrange
+ class QueueWithSingleFull:
+ def __init__(self) -> None:
+ self._first_put = True
+ self.items: list[dict[str, Any]] = [{"event": "old"}]
+
+ def put_nowait(self, item: dict[str, Any]) -> None:
+ if self._first_put:
+ self._first_put = False
+ raise queue.Full
+ self.items.append(item)
+
+ def get_nowait(self) -> dict[str, Any]:
+ if not self.items:
+ raise queue.Empty
+ return self.items.pop(0)
+
+ def empty(self) -> bool:
+ return len(self.items) == 0
+
+ fake_queue = QueueWithSingleFull()
+ monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
+
+ class Subscription:
+ def __init__(self) -> None:
+ self._calls = 0
+
+ def receive(self, timeout: int = 1) -> bytes | None:
+ self._calls += 1
+ if self._calls == 1:
+ return b'{"event":"node_started","task_id":"task-2"}'
+ return None
+
+ subscription = Subscription()
+
+ # Act
+ buffer_state = service_module._start_buffering(subscription)
+ ready = buffer_state.task_id_ready.wait(timeout=1)
+ buffer_state.stop_event.set()
+ finished = buffer_state.done_event.wait(timeout=1)
+
+ # Assert
+ assert ready is True
+ assert finished is True
+ assert fake_queue.items[-1]["task_id"] == "task-2"
+
+
+def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
+ # Arrange
+ class Subscription:
+ def receive(self, timeout: int = 1) -> bytes | None:
+ raise RuntimeError("subscription failure")
+
+ subscription = Subscription()
+
+ # Act
+ buffer_state = service_module._start_buffering(subscription)
+ finished = buffer_state.done_event.wait(timeout=1)
+
+ # Assert
+ assert finished is True
+
+
+def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ # Arrange
+ workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
+ topic = _Topic(_StaticSubscription())
+ workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
+ node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
+ factory = SimpleNamespace(
+ create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
+ create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
+ )
+ monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
+ monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
+ monkeypatch.setattr(
+ service_module,
+ "_get_message_context",
+ MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
+ )
+ monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
+ buffer_state = BufferState(
+ queue=queue.Queue(),
+ stop_event=Event(),
+ done_event=Event(),
+ task_id_ready=Event(),
+ task_id_hint="task-1",
+ )
+ monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
+ monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
+ monkeypatch.setattr(
+ service_module,
+ "_build_snapshot_events",
+ MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
+ )
+
+ # Act
+ events = list(
+ build_workflow_event_stream(
+ app_mode=AppMode.ADVANCED_CHAT,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=MagicMock(),
+ )
+ )
+
+ # Assert
+ assert events[0] == StreamEvent.PING.value
+ finished_event = cast(Mapping[str, Any], events[1])
+ assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
+ assert buffer_state.stop_event.is_set() is True
+ node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
+ called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
+ assert called_kwargs["workflow_run_id"] == "run-1"
+
+
+def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ # Arrange
+ workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
+ topic = _Topic(_StaticSubscription())
+ workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
+ node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
+ factory = SimpleNamespace(
+ create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
+ create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
+ )
+ monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
+ monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
+ monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
+ monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
+ monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
+
+ class AlwaysEmptyQueue:
+ def empty(self) -> bool:
+ return False
+
+ def get(self, timeout: int = 1) -> None:
+ raise queue.Empty
+
+ buffer_state = BufferState(
+ queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
+ stop_event=Event(),
+ done_event=Event(),
+ task_id_ready=Event(),
+ task_id_hint="task-1",
+ )
+ monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
+ time_values = cycle([0.0, 6.0, 21.0, 26.0])
+ monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
+
+ # Act
+ events = list(
+ build_workflow_event_stream(
+ app_mode=AppMode.WORKFLOW,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=MagicMock(),
+ idle_timeout=20.0,
+ ping_interval=5.0,
+ )
+ )
+
+ # Assert
+ assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
+ assert buffer_state.stop_event.is_set() is True
+
+
+def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ # Arrange
+ workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
+ topic = _Topic(_StaticSubscription())
+ workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
+ node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
+ factory = SimpleNamespace(
+ create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
+ create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
+ )
+ monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
+ monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
+ monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
+ monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
+ monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
+ buffer_state = BufferState(
+ queue=queue.Queue(),
+ stop_event=Event(),
+ done_event=Event(),
+ task_id_ready=Event(),
+ task_id_hint="task-1",
+ )
+ buffer_state.done_event.set()
+ monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
+
+ # Act
+ events = list(
+ build_workflow_event_stream(
+ app_mode=AppMode.WORKFLOW,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=MagicMock(),
+ )
+ )
+
+ # Assert
+ assert events == [StreamEvent.PING.value]
+ assert buffer_state.stop_event.is_set() is True
+
+
+def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ # Arrange
+ workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
+ topic = _Topic(_StaticSubscription())
+ workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
+ node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
+ factory = SimpleNamespace(
+ create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
+ create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
+ )
+ monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
+ monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
+ monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
+ monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
+ snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
+ monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
+ buffer_state = BufferState(
+ queue=queue.Queue(),
+ stop_event=Event(),
+ done_event=Event(),
+ task_id_ready=Event(),
+ task_id_hint="task-1",
+ )
+ monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
+
+ # Act
+ events = list(
+ build_workflow_event_stream(
+ app_mode=AppMode.WORKFLOW,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=MagicMock(),
+ )
+ )
+
+ # Assert
+ assert events[0] == StreamEvent.PING.value
+ assert snapshot_builder.call_args.kwargs["pause_entity"] is None
+
+
+def test_is_terminal_event_respects_close_on_pause_flag() -> None:
+ pause_event = {"event": "workflow_paused"}
+ finish_event = {"event": "workflow_finished"}
+
+ assert _is_terminal_event(pause_event, close_on_pause=True) is True
+ assert _is_terminal_event(pause_event, close_on_pause=False) is False
+ assert _is_terminal_event(finish_event, close_on_pause=False) is True
+
+
+def test_build_snapshot_events_preserves_public_form_token(monkeypatch: pytest.MonkeyPatch) -> None:
+ workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
+ snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
+ resumption_context = _build_resumption_context("task-ctx")
+ monkeypatch.setattr(service_module, "load_form_tokens_by_form_id", lambda form_ids, session=None: {"form-1": "wtok"})
+ session_maker = _SessionMaker(
+ SimpleNamespace(
+ execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
+ )
+ )
+ pause_entity = _FakePauseEntity(
+ pause_id="pause-1",
+ workflow_run_id="run-1",
+ paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
+ pause_reasons=[
+ HumanInputRequired(
+ form_id="form-1",
+ form_content="content",
+ node_id="node-1",
+ node_title="Human Input",
+ form_token="wtok",
+ )
+ ],
+ )
+
+ events = _build_snapshot_events(
+ workflow_run=workflow_run,
+ node_snapshots=[snapshot],
+ task_id="task-ctx",
+ message_context=None,
+ pause_entity=pause_entity,
+ resumption_context=resumption_context,
+ session_maker=cast(sessionmaker[Session], session_maker),
+ )
+
+ assert events[-2]["event"] == StreamEvent.HUMAN_INPUT_REQUIRED.value
+ assert events[-2]["data"]["form_token"] == "wtok"
+ pause_data = events[-1]["data"]
+ assert pause_data["reasons"][0]["form_token"] == "wtok"
+
+
+def test_build_workflow_event_stream_loads_pause_tokens_without_flask_app_context(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
+ topic = _Topic(_StaticSubscription())
+ pause_entity = _FakePauseEntity(
+ pause_id="pause-1",
+ workflow_run_id="run-1",
+ paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
+ pause_reasons=[
+ HumanInputRequired(
+ form_id="form-1",
+ form_content="content",
+ node_id="node-1",
+ node_title="Human Input",
+ )
+ ],
+ )
+ workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(return_value=pause_entity))
+ node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
+ factory = SimpleNamespace(
+ create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
+ create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
+ )
+ monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
+ monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
+ monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=_build_resumption_context("task-1")))
+ monkeypatch.setattr(service_module, "load_form_tokens_by_form_id", lambda form_ids, session=None: {"form-1": "wtok"})
+
+ session = SimpleNamespace(
+ scalar=MagicMock(return_value=None),
+ execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
+ )
+ session_maker = _SessionMaker(session)
+
+ events = list(
+ build_workflow_event_stream(
+ app_mode=AppMode.WORKFLOW,
+ workflow_run=workflow_run,
+ tenant_id="tenant-1",
+ app_id="app-1",
+ session_maker=cast(sessionmaker[Session], session_maker),
+ )
+ )
+
+ pause_event = cast(Mapping[str, Any], events[-1])
+ assert pause_event["event"] == StreamEvent.WORKFLOW_PAUSED.value
+ assert pause_event["data"]["reasons"][0]["form_token"] == "wtok"
diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py
index d3cf632b47..0cedc387c0 100644
--- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py
+++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py
@@ -7,11 +7,16 @@ from unittest.mock import MagicMock
import pytest
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from models.enums import CreatorUserRole
from models.model import App, AppMode, Conversation
from models.workflow import Workflow, WorkflowRun
-from tasks.app_generate.workflow_execute_task import _publish_streaming_response, _resume_app_execution
+from tasks.app_generate.workflow_execute_task import (
+ _publish_streaming_response,
+ _resume_advanced_chat,
+ _resume_app_execution,
+ _resume_workflow,
+)
class _FakeSessionContext:
@@ -38,12 +43,28 @@ def _build_advanced_chat_generate_entity(conversation_id: str | None) -> Advance
)
+def _build_workflow_generate_entity(stream: bool) -> WorkflowAppGenerateEntity:
+ return WorkflowAppGenerateEntity(
+ task_id="task-id",
+ inputs={},
+ files=[],
+ user_id="user-id",
+ stream=stream,
+ invoke_from=InvokeFrom.WEB_APP,
+ workflow_execution_id="workflow-run-id",
+ )
+
+
+def _single_event_generator(payload):
+ yield payload
+
+
@pytest.fixture
-def mock_topic(mocker) -> MagicMock:
+def mock_topic(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
topic = MagicMock()
- mocker.patch(
+ monkeypatch.setattr(
"tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
- return_value=topic,
+ lambda *_args, **_kwargs: topic,
)
return topic
@@ -67,31 +88,35 @@ def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())
-def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(mocker):
+def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(monkeypatch: pytest.MonkeyPatch):
workflow_run_id = "run-id"
conversation_id = "conversation-id"
message = MagicMock()
- mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object()))
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object()))
pause_entity = MagicMock()
pause_entity.get_state.return_value = b"state"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_pause.return_value = pause_entity
- mocker.patch(
+ monkeypatch.setattr(
"tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
- return_value=workflow_run_repo,
+ lambda *_args, **_kwargs: workflow_run_repo,
)
generate_entity = _build_advanced_chat_generate_entity(conversation_id)
resumption_context = MagicMock()
resumption_context.serialized_graph_runtime_state = "{}"
resumption_context.get_generate_entity.return_value = generate_entity
- mocker.patch(
- "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads",
+ lambda *_args, **_kwargs: resumption_context,
+ )
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot",
+ lambda *_args, **_kwargs: MagicMock(),
)
- mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock())
workflow_run = SimpleNamespace(
workflow_id="wf-id",
@@ -120,10 +145,11 @@ def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(m
session.get.side_effect = _session_get
session.scalar.return_value = message
- mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session))
- mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock())
- resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat")
- mocker.patch("tasks.app_generate.workflow_execute_task._resume_workflow")
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task.Session", lambda *_args, **_kwargs: _FakeSessionContext(session))
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resolve_user_for_run", lambda *_args, **_kwargs: MagicMock())
+ resume_advanced_chat = MagicMock()
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_advanced_chat", resume_advanced_chat)
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_workflow", MagicMock())
_resume_app_execution({"workflow_run_id": workflow_run_id})
@@ -144,29 +170,35 @@ def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(m
assert resume_advanced_chat.call_args.kwargs["message"] is message
-def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id(mocker):
+def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id(
+ monkeypatch: pytest.MonkeyPatch,
+):
workflow_run_id = "run-id"
- mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object()))
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object()))
pause_entity = MagicMock()
pause_entity.get_state.return_value = b"state"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_pause.return_value = pause_entity
- mocker.patch(
+ monkeypatch.setattr(
"tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
- return_value=workflow_run_repo,
+ lambda *_args, **_kwargs: workflow_run_repo,
)
generate_entity = _build_advanced_chat_generate_entity(conversation_id=None)
resumption_context = MagicMock()
resumption_context.serialized_graph_runtime_state = "{}"
resumption_context.get_generate_entity.return_value = generate_entity
- mocker.patch(
- "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads",
+ lambda *_args, **_kwargs: resumption_context,
+ )
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot",
+ lambda *_args, **_kwargs: MagicMock(),
)
- mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock())
workflow_run = SimpleNamespace(
workflow_id="wf-id",
@@ -191,12 +223,99 @@ def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversat
session.get.side_effect = _session_get
- mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session))
- mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock())
- resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat")
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task.Session", lambda *_args, **_kwargs: _FakeSessionContext(session))
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resolve_user_for_run", lambda *_args, **_kwargs: MagicMock())
+ resume_advanced_chat = MagicMock()
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._resume_advanced_chat", resume_advanced_chat)
_resume_app_execution({"workflow_run_id": workflow_run_id})
session.scalar.assert_not_called()
workflow_run_repo.resume_workflow_pause.assert_not_called()
resume_advanced_chat.assert_not_called()
+
+
+def test_resume_advanced_chat_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch):
+ generate_entity = _build_advanced_chat_generate_entity(conversation_id="conversation-id")
+ generate_entity.stream = False
+
+ generator_instance = MagicMock()
+ response_stream = _single_event_generator({"event": "message"})
+ generator_instance.resume.return_value = response_stream
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.AdvancedChatAppGenerator",
+ lambda: generator_instance,
+ )
+
+ publish_streaming_response = MagicMock()
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._publish_streaming_response", publish_streaming_response)
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_execution_repository",
+ lambda **kwargs: MagicMock(),
+ )
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ lambda **kwargs: MagicMock(),
+ )
+
+ _resume_advanced_chat(
+ app_model=SimpleNamespace(id="app-id"),
+ workflow=SimpleNamespace(created_by="workflow-owner"),
+ user=MagicMock(),
+ conversation=SimpleNamespace(id="conversation-id"),
+ message=MagicMock(),
+ generate_entity=generate_entity,
+ graph_runtime_state=MagicMock(),
+ session_factory=MagicMock(),
+ pause_state_config=MagicMock(),
+ workflow_run_id="workflow-run-id",
+ workflow_run=SimpleNamespace(triggered_from="app_run"),
+ )
+
+ resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"]
+ assert resumed_entity.stream is True
+ publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.ADVANCED_CHAT)
+
+
+def test_resume_workflow_publishes_events_for_originally_blocking_runs(monkeypatch: pytest.MonkeyPatch):
+ generate_entity = _build_workflow_generate_entity(stream=False)
+
+ generator_instance = MagicMock()
+ response_stream = _single_event_generator({"event": "workflow_finished"})
+ generator_instance.resume.return_value = response_stream
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.WorkflowAppGenerator",
+ lambda: generator_instance,
+ )
+
+ publish_streaming_response = MagicMock()
+ monkeypatch.setattr("tasks.app_generate.workflow_execute_task._publish_streaming_response", publish_streaming_response)
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_execution_repository",
+ lambda **kwargs: MagicMock(),
+ )
+ monkeypatch.setattr(
+ "tasks.app_generate.workflow_execute_task.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ lambda **kwargs: MagicMock(),
+ )
+ workflow_run_repo = MagicMock()
+ pause_entity = MagicMock()
+
+ _resume_workflow(
+ app_model=SimpleNamespace(id="app-id"),
+ workflow=SimpleNamespace(created_by="workflow-owner"),
+ user=MagicMock(),
+ generate_entity=generate_entity,
+ graph_runtime_state=MagicMock(),
+ session_factory=MagicMock(),
+ pause_state_config=MagicMock(),
+ workflow_run_id="workflow-run-id",
+ workflow_run=SimpleNamespace(triggered_from="app_run"),
+ workflow_run_repo=workflow_run_repo,
+ pause_entity=pause_entity,
+ )
+
+ resumed_entity = generator_instance.resume.call_args.kwargs["application_generate_entity"]
+ assert resumed_entity.stream is True
+ publish_streaming_response.assert_called_once_with(response_stream, "workflow-run-id", AppMode.WORKFLOW)
+ workflow_run_repo.delete_workflow_pause.assert_called_once_with(pause_entity)
diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx
index 85cc82fc57..d9ee9bcc1e 100644
--- a/web/app/components/develop/template/template_advanced_chat.en.mdx
+++ b/web/app/components/develop/template/template_advanced_chat.en.mdx
@@ -272,6 +272,12 @@ Chat applications support session persistence, allowing previous chat history to
}'`}
/>
### Blocking Mode
+ Blocking mode can return a normal chat message or a paused workflow response.
+
+ When advanced chat pauses for Human-in-the-Loop, `event` becomes `workflow_paused`.
+ The payload still includes `message_id`, `conversation_id`, `answer`, and `workflow_run_id`, while its `data` object additionally contains `paused_nodes` and `reasons`.
+ For `human_input_required`, each reason contains the `form_id` and its `expiration_time`.
+