mirror of https://github.com/langgenius/dify.git
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from collections.abc import Callable, Generator, Iterable, Mapping
|
|
from typing import Any
|
|
|
|
from core.app.entities.task_entities import StreamEvent
|
|
from libs.broadcast_channel.channel import Topic
|
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
|
|
|
|
|
def stream_topic_events(
|
|
*,
|
|
topic: Topic,
|
|
idle_timeout: float,
|
|
ping_interval: float | None = None,
|
|
on_subscribe: Callable[[], None] | None = None,
|
|
terminal_events: Iterable[str | StreamEvent] | None = None,
|
|
) -> Generator[Mapping[str, Any] | str, None, None]:
|
|
# send a PING event immediately to prevent the connection staying in pending state for a long time.
|
|
#
|
|
# This simplify the debugging process as the DevTools in Chrome does not
|
|
# provide complete curl command for pending connections.
|
|
yield StreamEvent.PING.value
|
|
|
|
terminal_values = _normalize_terminal_events(terminal_events)
|
|
last_msg_time = time.time()
|
|
last_ping_time = last_msg_time
|
|
with topic.subscribe() as sub:
|
|
# on_subscribe fires only after the Redis subscription is active.
|
|
# This is used to gate task start and reduce pub/sub race for the first event.
|
|
if on_subscribe is not None:
|
|
on_subscribe()
|
|
while True:
|
|
try:
|
|
msg = sub.receive(timeout=0.1)
|
|
except SubscriptionClosedError:
|
|
return
|
|
if msg is None:
|
|
current_time = time.time()
|
|
if current_time - last_msg_time > idle_timeout:
|
|
return
|
|
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
|
|
yield StreamEvent.PING.value
|
|
last_ping_time = current_time
|
|
continue
|
|
|
|
last_msg_time = time.time()
|
|
last_ping_time = last_msg_time
|
|
event = json.loads(msg)
|
|
yield event
|
|
if not isinstance(event, dict):
|
|
continue
|
|
|
|
event_type = event.get("event")
|
|
if event_type in terminal_values:
|
|
return
|
|
|
|
|
|
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
|
if not terminal_events:
|
|
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
|
values: set[str] = set()
|
|
for item in terminal_events:
|
|
if isinstance(item, StreamEvent):
|
|
values.add(item.value)
|
|
else:
|
|
values.add(str(item))
|
|
return values
|