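"""Rebuild a server-sent-event stream for a workflow run from persisted snapshots.

Historical events (workflow start, node start/finish, and a pause event for
paused runs) are replayed first; afterwards, live events from the run's
pub/sub topic are relayed until a terminal event arrives.
"""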
from __future__ import annotations

import json
import logging
import queue
import threading
import time
from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import Any

from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker

from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
    MessageReplaceStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    StreamEvent,
    WorkflowPauseStreamResponse,
    WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.entities import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.factory import DifyAPIRepositoryFactory

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class MessageContext:
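    """Chat-message context attached to replayed events in advanced-chat mode."""
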
    conversation_id: str
    message_id: str
    created_at: int
    answer: str | None = None


@dataclass
class BufferState:
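    """State shared between the streaming generator and the buffering worker thread."""
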
    queue: queue.Queue[Mapping[str, Any]]
    stop_event: threading.Event
    done_event: threading.Event
    task_id_ready: threading.Event
    task_id_hint: str | None = None


def build_workflow_event_stream(
    *,
    app_mode: AppMode,
    workflow_run: WorkflowRun,
    tenant_id: str,
    app_id: str,
    session_maker: sessionmaker[Session],
    idle_timeout: float = 300.0,
    ping_interval: float = 10.0,
) -> Generator[Mapping[str, Any] | str, None, None]:
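    """Build an SSE event stream that replays a workflow run's history, then follows live events.

    Snapshot events are emitted first; afterwards the stream relays live events
    from the run's pub/sub topic, interleaving PING keep-alives, until a
    terminal event is seen or the buffering worker finishes.
    """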
    topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

    message_context = (
        _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None
    )

    pause_entity: WorkflowPauseEntity | None = None
    if workflow_run.status == WorkflowExecutionStatus.PAUSED:
        try:
            pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id)
        except Exception:
            logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
            pause_entity = None
    resumption_context = _load_resumption_context(pause_entity)

    node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_run.workflow_id,
        # NOTE(QuantumGhost): for event resumption, we only care about
        # execution records triggered from `WORKFLOW_RUN`.
        #
        # Ideally, filtering by `workflow_run_id` alone would be enough. However,
        # due to the index on the `WorkflowNodeExecution` table, we have to add a
        # filter condition on `triggered_from` to ensure the index can be used.
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        workflow_run_id=workflow_run.id,
    )

    def _generate() -> Generator[Mapping[str, Any] | str, None, None]:
        # Send a PING event immediately to keep the connection from staying in a
        # pending state for a long time.
        #
        # This simplifies debugging, as Chrome DevTools does not provide a
        # complete curl command for pending connections.
        yield StreamEvent.PING.value
        last_msg_time = time.time()
        last_ping_time = last_msg_time
        with topic.subscribe() as sub:
            buffer_state = _start_buffering(sub)
            try:
                task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id)
                snapshot_events = _build_snapshot_events(
                    workflow_run=workflow_run,
                    node_snapshots=node_snapshots,
                    task_id=task_id,
                    message_context=message_context,
                    pause_entity=pause_entity,
                    resumption_context=resumption_context,
                )
                for event in snapshot_events:
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event, include_paused=True):
                        return
                while True:
                    if buffer_state.done_event.is_set() and buffer_state.queue.empty():
                        return
                    try:
                        event = buffer_state.queue.get(timeout=0.1)
                    except queue.Empty:
                        current_time = time.time()
                        if current_time - last_msg_time > idle_timeout:
                            logger.debug(
                                "No workflow events received for %s seconds, keeping stream open",
                                idle_timeout,
                            )
                            last_msg_time = current_time
                        if current_time - last_ping_time >= ping_interval:
                            yield StreamEvent.PING.value
                            last_ping_time = current_time
                        continue
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event, include_paused=True):
                        return
            finally:
                buffer_state.stop_event.set()

    return _generate()


def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None:
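    """Return the latest chat message attached to the workflow run, if any."""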
    with session_maker() as session:
        stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at))
        message = session.scalar(stmt)
        if message is None:
            return None
        created_at = int(message.created_at.timestamp()) if message.created_at else 0
        return MessageContext(
            conversation_id=message.conversation_id,
            message_id=message.id,
            created_at=created_at,
            answer=message.answer,
        )


def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None:
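    """Deserialize the persisted resumption context from a pause entity, tolerating corrupt state."""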
    if pause_entity is None:
        return None
    try:
        raw_state = pause_entity.get_state().decode()
        return WorkflowResumptionContext.loads(raw_state)
    except Exception:
        logger.exception("Failed to load resumption context")
        return None


def _resolve_task_id(
    resumption_context: WorkflowResumptionContext | None,
    buffer_state: BufferState | None,
    workflow_run_id: str,
    wait_timeout: float = 0.2,
) -> str:
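    """Pick the task id for replayed events.

    Preference order: the task id recorded in the resumption context, then a
    task id observed on a live buffered event (waiting up to wait_timeout
    seconds), and finally the workflow run id as a fallback.
    """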
    if resumption_context is not None:
        generate_entity = resumption_context.get_generate_entity()
        if generate_entity.task_id:
            return generate_entity.task_id
    if buffer_state is None:
        return workflow_run_id
    if buffer_state.task_id_hint is None:
        buffer_state.task_id_ready.wait(timeout=wait_timeout)
    if buffer_state.task_id_hint:
        return buffer_state.task_id_hint
    return workflow_run_id


def _build_snapshot_events(
    *,
    workflow_run: WorkflowRun,
    node_snapshots: Sequence[WorkflowNodeExecutionSnapshot],
    task_id: str,
    message_context: MessageContext | None,
    pause_entity: WorkflowPauseEntity | None,
    resumption_context: WorkflowResumptionContext | None,
) -> list[Mapping[str, Any]]:
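    """Rebuild the historical event sequence for a run from persisted snapshots.

    Order: workflow_started, an optional message_replace carrying the saved
    answer, node_started/node_finished pairs for each node snapshot (finished
    only for nodes no longer RUNNING), and a trailing pause event for paused runs.
    """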
    events: list[Mapping[str, Any]] = []
    workflow_started = _build_workflow_started_event(
        workflow_run=workflow_run,
        task_id=task_id,
    )
    _apply_message_context(workflow_started, message_context)
    events.append(workflow_started)

    if message_context is not None and message_context.answer is not None:
        message_replace = _build_message_replace_event(task_id=task_id, answer=message_context.answer)
        _apply_message_context(message_replace, message_context)
        events.append(message_replace)

    for snapshot in node_snapshots:
        node_started = _build_node_started_event(
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            snapshot=snapshot,
        )
        _apply_message_context(node_started, message_context)
        events.append(node_started)
        if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value:
            node_finished = _build_node_finished_event(
                workflow_run_id=workflow_run.id,
                task_id=task_id,
                snapshot=snapshot,
            )
            _apply_message_context(node_finished, message_context)
            events.append(node_finished)

    if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
        pause_event = _build_pause_event(
            workflow_run=workflow_run,
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            pause_entity=pause_entity,
            resumption_context=resumption_context,
        )
        if pause_event is not None:
            _apply_message_context(pause_event, message_context)
            events.append(pause_event)
    return events


def _build_workflow_started_event(
    *,
    workflow_run: WorkflowRun,
    task_id: str,
) -> dict[str, Any]:
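    """Replay a workflow_started event with `reason` set to INITIAL."""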
    response = WorkflowStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=WorkflowStartStreamResponse.Data(
            id=workflow_run.id,
            workflow_id=workflow_run.workflow_id,
            inputs=workflow_run.inputs_dict or {},
            created_at=int(workflow_run.created_at.timestamp()),
            reason=WorkflowStartReason.INITIAL,
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]:
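    """Replay the previously generated answer so the client can restore message content."""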
    response = MessageReplaceStreamResponse(
        task_id=task_id,
        answer=answer,
        reason="",
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _build_node_started_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
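    """Replay a node_started event from a node execution snapshot."""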
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    response = NodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeStartStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            created_at=created_at,
            extras={},
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_node_finished_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
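    """Replay a node_finished event; detail fields (inputs, outputs, metadata) are not kept in snapshots."""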
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at
    response = NodeFinishStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeFinishStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            process_data=None,
            outputs=None,
            status=snapshot.status,
            error=None,
            elapsed_time=snapshot.elapsed_time,
            execution_metadata=None,
            created_at=created_at,
            finished_at=finished_at,
            files=[],
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_pause_event(
    *,
    workflow_run: WorkflowRun,
    workflow_run_id: str,
    task_id: str,
    pause_entity: WorkflowPauseEntity,
    resumption_context: WorkflowResumptionContext | None,
) -> dict[str, Any] | None:
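    """Replay a workflow_paused event, recovering paused nodes and outputs from the saved runtime state."""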
    paused_nodes: list[str] = []
    outputs: dict[str, Any] = {}
    if resumption_context is not None:
        state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
        paused_nodes = state.get_paused_nodes()
        outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
    reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
    response = WorkflowPauseStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=WorkflowPauseStreamResponse.Data(
            workflow_run_id=workflow_run_id,
            paused_nodes=paused_nodes,
            outputs=outputs,
            reasons=reasons,
            status=workflow_run.status.value,
            created_at=int(workflow_run.created_at.timestamp()),
            elapsed_time=float(workflow_run.elapsed_time or 0.0),
            total_tokens=int(workflow_run.total_tokens or 0),
            total_steps=int(workflow_run.total_steps or 0),
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None:
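    """Stamp chat fields (conversation_id, message_id, created_at) onto an event payload in place."""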
    if message_context is None:
        return
    payload["conversation_id"] = message_context.conversation_id
    payload["message_id"] = message_context.message_id
    payload["created_at"] = message_context.created_at


def _start_buffering(subscription) -> BufferState:
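    """Drain a topic subscription on a daemon thread into a bounded queue.

    Buffering starts before the snapshot events are yielded so that live events
    published during the replay are not lost. When the queue is full, the
    oldest buffered event is dropped to make room for the newest one.
    """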
    buffer_state = BufferState(
        queue=queue.Queue(maxsize=2048),
        stop_event=threading.Event(),
        done_event=threading.Event(),
        task_id_ready=threading.Event(),
    )

    def _worker() -> None:
        dropped_count = 0
        try:
            while not buffer_state.stop_event.is_set():
                msg = subscription.receive(timeout=0.1)
                if msg is None:
                    continue
                event = _parse_event_message(msg)
                if event is None:
                    continue
                task_id = event.get("task_id")
                if task_id and buffer_state.task_id_hint is None:
                    buffer_state.task_id_hint = str(task_id)
                    buffer_state.task_id_ready.set()
                try:
                    buffer_state.queue.put_nowait(event)
                except queue.Full:
                    dropped_count += 1
                    try:
                        buffer_state.queue.get_nowait()
                    except queue.Empty:
                        pass
                    try:
                        buffer_state.queue.put_nowait(event)
                    except queue.Full:
                        continue
                    logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
        except Exception:
            logger.exception("Failed while buffering workflow events")
        finally:
            buffer_state.done_event.set()

    thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
    thread.start()
    return buffer_state


def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
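    """Decode a raw pub/sub payload into an event mapping; non-JSON and non-dict payloads are dropped."""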
    try:
        event = json.loads(message)
    except json.JSONDecodeError:
        logger.warning("Failed to decode workflow event payload")
        return None
    if not isinstance(event, dict):
        return None
    return event


def _is_terminal_event(event: Mapping[str, Any] | str, include_paused: bool = False) -> bool:
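    """Return True for events that should end the stream: workflow_finished, and optionally workflow_paused."""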
    if not isinstance(event, Mapping):
        return False
    event_type = event.get("event")
    if event_type == StreamEvent.WORKFLOW_FINISHED.value:
        return True
    if include_paused:
        return event_type == StreamEvent.WORKFLOW_PAUSED.value
    return False
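

# A minimal usage sketch (illustrative only; the endpoint shape, `app_model`
# attributes, and SSE framing below are assumptions, not part of this module):
#
#     from flask import Response
#
#     def stream_workflow_run(app_model, workflow_run, session_maker):
#         generator = build_workflow_event_stream(
#             app_mode=AppMode(app_model.mode),
#             workflow_run=workflow_run,
#             tenant_id=app_model.tenant_id,
#             app_id=app_model.id,
#             session_maker=session_maker,
#         )
#
#         def _sse():
#             for event in generator:
#                 if isinstance(event, str):
#                     # Bare strings are keep-alive markers such as StreamEvent.PING.value.
#                     yield f"event: {event}\n\n"
#                 else:
#                     yield f"data: {json.dumps(event)}\n\n"
#
#         return Response(_sse(), mimetype="text/event-stream")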