# dify/api/core/app/apps/streaming_utils.py
#
# 71 lines
# 2.5 KiB
# Python

from __future__ import annotations
import json
import time
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any
from core.app.entities.task_entities import StreamEvent
from libs.broadcast_channel.channel import Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
def stream_topic_events(
    *,
    topic: Topic,
    idle_timeout: float,
    ping_interval: float | None = None,
    on_subscribe: Callable[[], None] | None = None,
    terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]:
    """Relay messages from ``topic`` as decoded events until the stream ends.

    The generator first yields ``StreamEvent.PING.value``, then subscribes to
    ``topic`` and yields every received message after decoding it with
    ``json.loads``. It returns when any of the following happens:

    * ``sub.receive`` raises ``SubscriptionClosedError``;
    * more than ``idle_timeout`` seconds pass without a message (the idle
      clock starts just before subscribing);
    * a decoded ``dict`` event has an ``"event"`` field that is one of the
      terminal values. The terminal event itself is yielded before returning.

    Args:
        topic: Topic to subscribe to for the lifetime of the generator.
        idle_timeout: Maximum number of seconds without a message before the
            stream is ended.
        ping_interval: When not ``None``, a ping is yielded once at least this
            many seconds have passed since the last message or ping.
        on_subscribe: Optional callback invoked once, right after the
            subscription context has been entered.
        terminal_events: Event names or ``StreamEvent`` members that end the
            stream; normalized by ``_normalize_terminal_events``.

    Yields:
        ``StreamEvent.PING.value`` for pings, otherwise the decoded message.

    Raises:
        json.JSONDecodeError: If a received message is not valid JSON.
    """
    # Send a PING event immediately so the connection does not stay in the
    # pending state for a long time.
    #
    # This simplifies debugging, as Chrome DevTools does not provide a
    # complete curl command for pending connections.
    yield StreamEvent.PING.value

    terminal_values = _normalize_terminal_events(terminal_events)

    # Elapsed time is measured with the monotonic clock so that wall-clock
    # adjustments (NTP, manual changes) can neither trigger a premature idle
    # timeout nor suppress pings.
    last_msg_time = time.monotonic()
    last_ping_time = last_msg_time

    with topic.subscribe() as sub:
        # on_subscribe fires only after the Redis subscription is active.
        # This is used to gate task start and reduce pub/sub race for the first event.
        if on_subscribe is not None:
            on_subscribe()

        while True:
            try:
                # Short poll so idle/ping bookkeeping runs regularly.
                msg = sub.receive(timeout=0.1)
            except SubscriptionClosedError:
                return

            if msg is None:
                now = time.monotonic()
                if now - last_msg_time > idle_timeout:
                    return
                if ping_interval is not None and now - last_ping_time >= ping_interval:
                    yield StreamEvent.PING.value
                    last_ping_time = now
                continue

            # A real message counts as activity for both timers.
            last_msg_time = time.monotonic()
            last_ping_time = last_msg_time

            event = json.loads(msg)
            yield event

            if not isinstance(event, dict):
                continue

            event_type = event.get("event")
            # Terminal values are strings; the type check also keeps an
            # unhashable "event" payload (list/object) from raising TypeError
            # in the set membership test.
            if isinstance(event_type, str) and event_type in terminal_values:
                return
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:
if isinstance(item, StreamEvent):
values.add(item.value)
else:
values.add(str(item))
return values