mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 10:16:40 +08:00
fix:update latest commits (#53)
* test: adding some web tests (#27792) * feat: add validation to prevent saving empty opening statement in conversation opener modal (#27843) * fix(web): improve the consistency of the inputs-form UI (#27837) * fix(web): increase z-index of PortalToFollowElemContent (#27823) * fix: installation_id is missing when in tools page (#27849) * fix: avoid passing empty uniqueIdentifier to InstallFromMarketplace (#27802) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * test: create new test scripts and update some existing test scripts o… (#27850) * feat: change feedback to forum (#27862) * chore: translate i18n files and update type definitions (#27868) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> * Fix/template transformer line number (#27867) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> * bump vite to 6.4.1 (#27877) * Add WEAVIATE_GRPC_ENDPOINT as designed in weaviate migration guide (#27861) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> * Fix: correct DraftWorkflowApi.post response model (#27289) Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> * fix Version 2.0.0-beta.2: Chat annotations Api Error #25506 (#27206) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> * fix jina reader creadential migration command (#27883) * fix agent putout the output of workflow-tool twice (#26835) (#27087) * fix jina reader transform (#27922) * fix: prevent fetch version info in enterprise edition (#27923) * fix(api): fix `VariablePool.get` adding unexpected keys to variable_dictionary (#26767) Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: implement tenant self queue for rag tasks (#27559) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> * fix: bump brotli to 1.2.0 resloved CVE-2025-6176 (#27950) Signed-off-by: kenwoodjw <blackxin55+@gmail.com> * docs: clarify how to obtain workflow_id for version execution (#28007) Signed-off-by: OneZero-Y <aukovyps@163.com> * fix: fix https://github.com/langgenius/dify/issues/27939 (#27985) * fix: the model list encountered two children with the same key (#27956) Co-authored-by: haokai <haokai@shuwen.com> * add onupdate=func.current_timestamp() (#28014) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * chore(deps): bump scipy-stubs from 1.16.2.3 to 1.16.3.0 in /api (#28025) Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Fix typo in weaviate comment, improve time test precision, and add security tests for get-icon utility (#27919) Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * feat: Add Audio Content Support for MCP Tools (#27979) * fix: elasticsearch_vector version (#28028) Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix workflow default updated_at (#28047) * feat(api): Introduce Broadcast Channel (#27835) This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most once delivery semantics, serving as the communication component between celery worker and API server. It also includes a reference implementation backed by Redis PubSub. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> * fix * back --------- Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com> Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Signed-off-by: OneZero-Y <aukovyps@163.com> Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: aka James4u <smart.jamesjin@gmail.com> Co-authored-by: Novice <novice12185727@gmail.com> Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com> Co-authored-by: Elliott <105957288+Elliott-byte@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: johnny0120 <johnny0120@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Gritty_dev <101377478+codomposer@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: wangjifeng <163279492+kk-wangjifeng@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: Cursx <33718736+Cursx@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: red_sun <56100962+redSun64@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: hj24 <huangjian@dify.ai> Co-authored-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: OneZero-Y <aukovyps@163.com> Co-authored-by: wangxiaolei <fatelei@gmail.com> Co-authored-by: Kenn <kennfalcon@gmail.com> Co-authored-by: haokai <haokai@shuwen.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: Will <vvfriday@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
This commit is contained in:
parent
3f86c863b8
commit
d1a6779bbb
@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
# Return composite sort key: (model_type value, model position index)
|
# Return composite sort key: (model_type value, model position index)
|
||||||
return (model.model_type.value, position_index)
|
return (model.model_type.value, position_index)
|
||||||
|
|
||||||
|
# Deduplicate
|
||||||
|
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
|
||||||
|
|
||||||
# Sort using the composite sort key
|
# Sort using the composite sort key
|
||||||
return sorted(provider_models, key=get_sort_key)
|
return sorted(provider_models, key=get_sort_key)
|
||||||
|
|
||||||
|
|||||||
@ -147,7 +147,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
def _get_version(self) -> str:
|
def _get_version(self) -> str:
|
||||||
info = self._client.info()
|
info = self._client.info()
|
||||||
return cast(str, info["version"]["number"])
|
# remove any suffix like "-SNAPSHOT" from the version string
|
||||||
|
return cast(str, info["version"]["number"]).split("-")[0]
|
||||||
|
|
||||||
def _check_version(self):
|
def _check_version(self):
|
||||||
if parse_version(self._version) < parse_version("8.0.0"):
|
if parse_version(self._version) < parse_version("8.0.0"):
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class WeaviateVector(BaseVector):
|
|||||||
|
|
||||||
# Parse gRPC configuration
|
# Parse gRPC configuration
|
||||||
if config.grpc_endpoint:
|
if config.grpc_endpoint:
|
||||||
# Urls without scheme won't be parsed correctly in some python verions,
|
# Urls without scheme won't be parsed correctly in some python versions,
|
||||||
# see https://bugs.python.org/issue27657
|
# see https://bugs.python.org/issue27657
|
||||||
grpc_endpoint_with_scheme = (
|
grpc_endpoint_with_scheme = (
|
||||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||||
|
|||||||
@ -1,16 +1,19 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPConnectionError
|
from core.mcp.error import MCPConnectionError
|
||||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MCPTool(Tool):
|
class MCPTool(Tool):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -52,6 +55,11 @@ class MCPTool(Tool):
|
|||||||
yield from self._process_text_content(content)
|
yield from self._process_text_content(content)
|
||||||
elif isinstance(content, ImageContent):
|
elif isinstance(content, ImageContent):
|
||||||
yield self._process_image_content(content)
|
yield self._process_image_content(content)
|
||||||
|
elif isinstance(content, AudioContent):
|
||||||
|
yield self._process_audio_content(content)
|
||||||
|
else:
|
||||||
|
logger.warning("Unsupported content type=%s", type(content))
|
||||||
|
|
||||||
# handle MCP structured output
|
# handle MCP structured output
|
||||||
if self.entity.output_schema and result.structuredContent:
|
if self.entity.output_schema and result.structuredContent:
|
||||||
for k, v in result.structuredContent.items():
|
for k, v in result.structuredContent.items():
|
||||||
@ -97,6 +105,10 @@ class MCPTool(Tool):
|
|||||||
"""Process image content and return a blob message."""
|
"""Process image content and return a blob message."""
|
||||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||||
|
|
||||||
|
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||||
|
"""Process audio content and return a blob message."""
|
||||||
|
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||||
return MCPTool(
|
return MCPTool(
|
||||||
entity=self.entity,
|
entity=self.entity,
|
||||||
|
|||||||
134
api/libs/broadcast_channel/channel.py
Normal file
134
api/libs/broadcast_channel/channel.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
Broadcast channel for Pub/Sub messaging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import types
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from typing import Protocol, Self
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(AbstractContextManager["Subscription"], Protocol):
|
||||||
|
"""A subscription to a topic that provides an iterator over received messages.
|
||||||
|
The subscription can be used as a context manager and will automatically
|
||||||
|
close when exiting the context.
|
||||||
|
|
||||||
|
Note: `Subscription` instances are not thread-safe. Each thread should create its own
|
||||||
|
subscription.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __iter__(self) -> Iterator[bytes]:
|
||||||
|
"""`__iter__` returns an iterator used to consume the message from this subscription.
|
||||||
|
|
||||||
|
If the caller did not enter the context, `__iter__` may lazily perform the setup before
|
||||||
|
yielding messages; otherwise `__enter__` handles it.”
|
||||||
|
|
||||||
|
If the subscription is closed, then the returned iterator exits without
|
||||||
|
raising any error.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
"""close closes the subscription, releases any resources associated with it."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: types.TracebackType | None,
|
||||||
|
) -> bool | None:
|
||||||
|
self.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||||
|
"""Receive the next message from the broadcast channel.
|
||||||
|
|
||||||
|
If `timeout` is specified, this method returns `None` if no message is
|
||||||
|
received within the given period. If `timeout` is `None`, the call blocks
|
||||||
|
until a message is received.
|
||||||
|
|
||||||
|
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
|
||||||
|
cancel a blocking subscription.
|
||||||
|
|
||||||
|
:param timeout: timeout for receive message, in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The received message as a byte string, or
|
||||||
|
None: If the timeout expires before a message is received.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SubscriptionClosed: If the subscription has already been closed.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Producer(Protocol):
|
||||||
|
"""Producer is an interface for message publishing. It is already bound to a specific topic.
|
||||||
|
|
||||||
|
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def publish(self, payload: bytes) -> None:
|
||||||
|
"""Publish a message to the bounded topic."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Subscriber(Protocol):
|
||||||
|
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
|
||||||
|
|
||||||
|
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def subscribe(self) -> Subscription:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Topic(Producer, Subscriber, Protocol):
|
||||||
|
"""A named channel for publishing and subscribing to messages.
|
||||||
|
|
||||||
|
Topics provide both read and write access. For restricted access,
|
||||||
|
use as_producer() for write-only view or as_subscriber() for read-only view.
|
||||||
|
|
||||||
|
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def as_producer(self) -> Producer:
|
||||||
|
"""as_producer creates a write-only view for this topic."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def as_subscriber(self) -> Subscriber:
|
||||||
|
"""as_subscriber create a read-only view for this topic."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastChannel(Protocol):
|
||||||
|
"""A broadcasting channel is a channel supporting broadcasting semantics.
|
||||||
|
|
||||||
|
Each channel is identified by a topic, different topics are isolated and do not affect each other.
|
||||||
|
|
||||||
|
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
|
||||||
|
a specific topic, all subscription should receive the published message.
|
||||||
|
|
||||||
|
There are no restriction for the persistence of messages. Once a subscription is created, it
|
||||||
|
should receive all subsequent messages published.
|
||||||
|
|
||||||
|
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def topic(self, topic: str) -> "Topic":
|
||||||
|
"""topic returns a `Topic` instance for the given topic name."""
|
||||||
|
...
|
||||||
12
api/libs/broadcast_channel/exc.py
Normal file
12
api/libs/broadcast_channel/exc.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
class BroadcastChannelError(Exception):
|
||||||
|
"""`BroadcastChannelError` is the base class for all exceptions related
|
||||||
|
to `BroadcastChannel`."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionClosedError(BroadcastChannelError):
|
||||||
|
"""SubscriptionClosedError means that the subscription has been closed and
|
||||||
|
methods for consuming messages should not be called."""
|
||||||
|
|
||||||
|
pass
|
||||||
3
api/libs/broadcast_channel/redis/__init__.py
Normal file
3
api/libs/broadcast_channel/redis/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .channel import BroadcastChannel
|
||||||
|
|
||||||
|
__all__ = ["BroadcastChannel"]
|
||||||
200
api/libs/broadcast_channel/redis/channel.py
Normal file
200
api/libs/broadcast_channel/redis/channel.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import types
|
||||||
|
from collections.abc import Generator, Iterator
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||||
|
from redis import Redis
|
||||||
|
from redis.client import PubSub
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastChannel:
|
||||||
|
"""
|
||||||
|
Redis Pub/Sub based broadcast channel implementation.
|
||||||
|
|
||||||
|
Provides "at most once" delivery semantics for messages published to channels.
|
||||||
|
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||||
|
|
||||||
|
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis_client: Redis,
|
||||||
|
):
|
||||||
|
self._client = redis_client
|
||||||
|
|
||||||
|
def topic(self, topic: str) -> "Topic":
|
||||||
|
return Topic(self._client, topic)
|
||||||
|
|
||||||
|
|
||||||
|
class Topic:
|
||||||
|
def __init__(self, redis_client: Redis, topic: str):
|
||||||
|
self._client = redis_client
|
||||||
|
self._topic = topic
|
||||||
|
|
||||||
|
def as_producer(self) -> Producer:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def publish(self, payload: bytes) -> None:
|
||||||
|
self._client.publish(self._topic, payload)
|
||||||
|
|
||||||
|
def as_subscriber(self) -> Subscriber:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def subscribe(self) -> Subscription:
|
||||||
|
return _RedisSubscription(
|
||||||
|
pubsub=self._client.pubsub(),
|
||||||
|
topic=self._topic,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RedisSubscription(Subscription):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pubsub: PubSub,
|
||||||
|
topic: str,
|
||||||
|
):
|
||||||
|
# The _pubsub is None only if the subscription is closed.
|
||||||
|
self._pubsub: PubSub | None = pubsub
|
||||||
|
self._topic = topic
|
||||||
|
self._closed = threading.Event()
|
||||||
|
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||||
|
self._dropped_count = 0
|
||||||
|
self._listener_thread: threading.Thread | None = None
|
||||||
|
self._start_lock = threading.Lock()
|
||||||
|
self._started = False
|
||||||
|
|
||||||
|
def _start_if_needed(self) -> None:
|
||||||
|
with self._start_lock:
|
||||||
|
if self._started:
|
||||||
|
return
|
||||||
|
if self._closed.is_set():
|
||||||
|
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||||
|
if self._pubsub is None:
|
||||||
|
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
|
||||||
|
|
||||||
|
self._pubsub.subscribe(self._topic)
|
||||||
|
_logger.debug("Subscribed to channel %s", self._topic)
|
||||||
|
|
||||||
|
self._listener_thread = threading.Thread(
|
||||||
|
target=self._listen,
|
||||||
|
name=f"redis-broadcast-{self._topic}",
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
self._listener_thread.start()
|
||||||
|
self._started = True
|
||||||
|
|
||||||
|
def _listen(self) -> None:
|
||||||
|
pubsub = self._pubsub
|
||||||
|
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||||
|
while not self._closed.is_set():
|
||||||
|
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||||
|
|
||||||
|
if raw_message is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if raw_message.get("type") != "message":
|
||||||
|
continue
|
||||||
|
|
||||||
|
channel_field = raw_message.get("channel")
|
||||||
|
if isinstance(channel_field, bytes):
|
||||||
|
channel_name = channel_field.decode("utf-8")
|
||||||
|
elif isinstance(channel_field, str):
|
||||||
|
channel_name = channel_field
|
||||||
|
else:
|
||||||
|
channel_name = str(channel_field)
|
||||||
|
|
||||||
|
if channel_name != self._topic:
|
||||||
|
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload_bytes: bytes | None = raw_message.get("data")
|
||||||
|
if not isinstance(payload_bytes, bytes):
|
||||||
|
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._enqueue_message(payload_bytes)
|
||||||
|
|
||||||
|
_logger.debug("Listener thread stopped for channel %s", self._topic)
|
||||||
|
pubsub.unsubscribe(self._topic)
|
||||||
|
pubsub.close()
|
||||||
|
_logger.debug("PubSub closed for topic %s", self._topic)
|
||||||
|
self._pubsub = None
|
||||||
|
|
||||||
|
def _enqueue_message(self, payload: bytes) -> None:
|
||||||
|
while not self._closed.is_set():
|
||||||
|
try:
|
||||||
|
self._queue.put_nowait(payload)
|
||||||
|
return
|
||||||
|
except queue.Full:
|
||||||
|
try:
|
||||||
|
self._queue.get_nowait()
|
||||||
|
self._dropped_count += 1
|
||||||
|
_logger.debug(
|
||||||
|
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
|
||||||
|
self._topic,
|
||||||
|
self._dropped_count,
|
||||||
|
)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
return
|
||||||
|
|
||||||
|
def _message_iterator(self) -> Generator[bytes, None, None]:
|
||||||
|
while not self._closed.is_set():
|
||||||
|
try:
|
||||||
|
item = self._queue.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield item
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[bytes]:
|
||||||
|
if self._closed.is_set():
|
||||||
|
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||||
|
self._start_if_needed()
|
||||||
|
return iter(self._message_iterator())
|
||||||
|
|
||||||
|
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||||
|
if self._closed.is_set():
|
||||||
|
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||||
|
self._start_if_needed()
|
||||||
|
|
||||||
|
try:
|
||||||
|
item = self._queue.get(timeout=timeout)
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
self._start_if_needed()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: types.TracebackType | None,
|
||||||
|
) -> bool | None:
|
||||||
|
self.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
if self._closed.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
self._closed.set()
|
||||||
|
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
|
||||||
|
# method should NOT be called concurrently.
|
||||||
|
#
|
||||||
|
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||||
|
listener = self._listener_thread
|
||||||
|
if listener is not None:
|
||||||
|
listener.join(timeout=1.0)
|
||||||
|
self._listener_thread = None
|
||||||
@ -111,7 +111,7 @@ class Account(UserMixin, TypeBase):
|
|||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
|
|
||||||
role: TenantAccountRole | None = field(default=None, init=False)
|
role: TenantAccountRole | None = field(default=None, init=False)
|
||||||
@ -251,7 +251,9 @@ class Tenant(TypeBase):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
def get_accounts(self) -> list[Account]:
|
def get_accounts(self) -> list[Account]:
|
||||||
return list(
|
return list(
|
||||||
@ -290,7 +292,7 @@ class TenantAccountJoin(TypeBase):
|
|||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -311,7 +313,7 @@ class AccountIntegrate(TypeBase):
|
|||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -397,5 +399,5 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
|
|||||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
|
|||||||
@ -61,18 +61,20 @@ class Dataset(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
embedding_model = mapped_column(String(255), nullable=True)
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
embedding_model_provider = mapped_column(String(255), nullable=True)
|
)
|
||||||
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
|
embedding_model = mapped_column(sa.String(255), nullable=True)
|
||||||
|
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
|
||||||
|
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
|
||||||
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
||||||
retrieval_model = mapped_column(sa.JSON, nullable=True)
|
retrieval_model = mapped_column(sa.JSON, nullable=True)
|
||||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
icon_info = mapped_column(sa.JSON, nullable=True)
|
icon_info = mapped_column(sa.JSON, nullable=True)
|
||||||
runtime_mode = mapped_column(String(255), nullable=True, server_default=sa.text("'general'"))
|
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||||
chunk_structure = mapped_column(String(255), nullable=True)
|
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
|
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_documents(self):
|
def total_documents(self):
|
||||||
@ -398,7 +400,9 @@ class Document(Base):
|
|||||||
archived_reason = mapped_column(String(255), nullable=True)
|
archived_reason = mapped_column(String(255), nullable=True)
|
||||||
archived_by = mapped_column(StringUUID, nullable=True)
|
archived_by = mapped_column(StringUUID, nullable=True)
|
||||||
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
doc_type = mapped_column(String(40), nullable=True)
|
doc_type = mapped_column(String(40), nullable=True)
|
||||||
doc_metadata = mapped_column(sa.JSON, nullable=True)
|
doc_metadata = mapped_column(sa.JSON, nullable=True)
|
||||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||||
@ -715,7 +719,9 @@ class DocumentSegment(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
error = mapped_column(LongText, nullable=True)
|
error = mapped_column(LongText, nullable=True)
|
||||||
@ -880,7 +886,7 @@ class ChildChunk(Base):
|
|||||||
)
|
)
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=sa.func.current_timestamp()
|
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
@ -1035,8 +1041,8 @@ class TidbAuthBinding(Base):
|
|||||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'"))
|
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
@ -1087,7 +1093,9 @@ class ExternalKnowledgeApis(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -1140,7 +1148,9 @@ class ExternalKnowledgeBindings(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DatasetAutoDisableLog(Base):
|
class DatasetAutoDisableLog(Base):
|
||||||
@ -1196,7 +1206,7 @@ class DatasetMetadata(Base):
|
|||||||
DateTime, nullable=False, server_default=sa.func.current_timestamp()
|
DateTime, nullable=False, server_default=sa.func.current_timestamp()
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=sa.func.current_timestamp()
|
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
@ -1223,44 +1233,48 @@ class DatasetMetadataBinding(Base):
|
|||||||
|
|
||||||
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||||
__tablename__ = "pipeline_built_in_templates"
|
__tablename__ = "pipeline_built_in_templates"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
name = mapped_column(String(255), nullable=False)
|
name = mapped_column(sa.String(255), nullable=False)
|
||||||
description = mapped_column(LongText, nullable=False)
|
description = mapped_column(LongText, nullable=False)
|
||||||
chunk_structure = mapped_column(String(255), nullable=False)
|
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||||
icon = mapped_column(sa.JSON, nullable=False)
|
icon = mapped_column(sa.JSON, nullable=False)
|
||||||
yaml_content = mapped_column(LongText, nullable=False)
|
yaml_content = mapped_column(LongText, nullable=False)
|
||||||
copyright = mapped_column(String(255), nullable=False)
|
copyright = mapped_column(sa.String(255), nullable=False)
|
||||||
privacy_policy = mapped_column(String(255), nullable=False)
|
privacy_policy = mapped_column(sa.String(255), nullable=False)
|
||||||
position = mapped_column(sa.Integer, nullable=False)
|
position = mapped_column(sa.Integer, nullable=False)
|
||||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
language = mapped_column(String(255), nullable=False)
|
language = mapped_column(sa.String(255), nullable=False)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||||
__tablename__ = "pipeline_customized_templates"
|
__tablename__ = "pipeline_customized_templates"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
||||||
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||||
name = mapped_column(String(255), nullable=False)
|
name = mapped_column(sa.String(255), nullable=False)
|
||||||
description = mapped_column(LongText, nullable=False)
|
description = mapped_column(LongText, nullable=False)
|
||||||
chunk_structure = mapped_column(String(255), nullable=False)
|
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||||
icon = mapped_column(sa.JSON, nullable=False)
|
icon = mapped_column(sa.JSON, nullable=False)
|
||||||
position = mapped_column(sa.Integer, nullable=False)
|
position = mapped_column(sa.Integer, nullable=False)
|
||||||
yaml_content = mapped_column(LongText, nullable=False)
|
yaml_content = mapped_column(LongText, nullable=False)
|
||||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
language = mapped_column(String(255), nullable=False)
|
language = mapped_column(sa.String(255), nullable=False)
|
||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_user_name(self):
|
def created_user_name(self):
|
||||||
@ -1272,19 +1286,21 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
|||||||
|
|
||||||
class Pipeline(Base): # type: ignore[name-defined]
|
class Pipeline(Base): # type: ignore[name-defined]
|
||||||
__tablename__ = "pipelines"
|
__tablename__ = "pipelines"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
name = mapped_column(String(255), nullable=False)
|
name = mapped_column(sa.String(255), nullable=False)
|
||||||
description = mapped_column(LongText, nullable=False, server_default=db.text("''"))
|
description = mapped_column(LongText, nullable=False, default=sa.text("''"))
|
||||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
created_by = mapped_column(StringUUID, nullable=True)
|
created_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
def retrieve_dataset(self, session: Session):
|
def retrieve_dataset(self, session: Session):
|
||||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||||
@ -1293,16 +1309,16 @@ class Pipeline(Base): # type: ignore[name-defined]
|
|||||||
class DocumentPipelineExecutionLog(Base):
|
class DocumentPipelineExecutionLog(Base):
|
||||||
__tablename__ = "document_pipeline_execution_logs"
|
__tablename__ = "document_pipeline_execution_logs"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
||||||
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
pipeline_id = mapped_column(StringUUID, nullable=False)
|
pipeline_id = mapped_column(StringUUID, nullable=False)
|
||||||
document_id = mapped_column(StringUUID, nullable=False)
|
document_id = mapped_column(StringUUID, nullable=False)
|
||||||
datasource_type = mapped_column(String(255), nullable=False)
|
datasource_type = mapped_column(sa.String(255), nullable=False)
|
||||||
datasource_info = mapped_column(LongText, nullable=False)
|
datasource_info = mapped_column(LongText, nullable=False)
|
||||||
datasource_node_id = mapped_column(String(255), nullable=False)
|
datasource_node_id = mapped_column(sa.String(255), nullable=False)
|
||||||
input_data = mapped_column(sa.JSON, nullable=False)
|
input_data = mapped_column(sa.JSON, nullable=False)
|
||||||
created_by = mapped_column(StringUUID, nullable=True)
|
created_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
@ -1310,7 +1326,7 @@ class DocumentPipelineExecutionLog(Base):
|
|||||||
|
|
||||||
class PipelineRecommendedPlugin(Base):
|
class PipelineRecommendedPlugin(Base):
|
||||||
__tablename__ = "pipeline_recommended_plugins"
|
__tablename__ = "pipeline_recommended_plugins"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
plugin_id = mapped_column(LongText, nullable=False)
|
plugin_id = mapped_column(LongText, nullable=False)
|
||||||
@ -1318,4 +1334,6 @@ class PipelineRecommendedPlugin(Base):
|
|||||||
position = mapped_column(sa.Integer, nullable=False, default=0)
|
position = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
active = mapped_column(sa.Boolean, nullable=False, default=True)
|
active = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|||||||
@ -97,7 +97,9 @@ class App(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=True)
|
created_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -316,7 +318,9 @@ class AppModelConfig(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=True)
|
created_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
opening_statement = mapped_column(LongText)
|
opening_statement = mapped_column(LongText)
|
||||||
suggested_questions = mapped_column(LongText)
|
suggested_questions = mapped_column(LongText)
|
||||||
suggested_questions_after_answer = mapped_column(LongText)
|
suggested_questions_after_answer = mapped_column(LongText)
|
||||||
@ -547,7 +551,9 @@ class RecommendedApp(Base):
|
|||||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
|
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app(self) -> App | None:
|
def app(self) -> App | None:
|
||||||
@ -646,7 +652,9 @@ class Conversation(Base):
|
|||||||
read_account_id = mapped_column(StringUUID)
|
read_account_id = mapped_column(StringUUID)
|
||||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||||
message_annotations = db.relationship(
|
message_annotations = db.relationship(
|
||||||
@ -950,7 +958,9 @@ class Message(Base):
|
|||||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||||
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
@ -1298,7 +1308,9 @@ class MessageFeedback(Base):
|
|||||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def from_account(self) -> Account | None:
|
def from_account(self) -> Account | None:
|
||||||
@ -1380,7 +1392,9 @@ class MessageAnnotation(Base):
|
|||||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||||
account_id = mapped_column(StringUUID, nullable=False)
|
account_id = mapped_column(StringUUID, nullable=False)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def account(self):
|
def account(self):
|
||||||
@ -1445,7 +1459,9 @@ class AppAnnotationSetting(Base):
|
|||||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def collection_binding_detail(self):
|
def collection_binding_detail(self):
|
||||||
@ -1473,7 +1489,9 @@ class OperationLog(Base):
|
|||||||
content = mapped_column(sa.JSON)
|
content = mapped_column(sa.JSON)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DefaultEndUserSessionID(StrEnum):
|
class DefaultEndUserSessionID(StrEnum):
|
||||||
@ -1512,7 +1530,9 @@ class EndUser(Base, UserMixin):
|
|||||||
|
|
||||||
session_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
session_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AppMCPServer(Base):
|
class AppMCPServer(Base):
|
||||||
@ -1532,7 +1552,9 @@ class AppMCPServer(Base):
|
|||||||
parameters = mapped_column(LongText, nullable=False)
|
parameters = mapped_column(LongText, nullable=False)
|
||||||
|
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_server_code(n: int) -> str:
|
def generate_server_code(n: int) -> str:
|
||||||
@ -1578,7 +1600,9 @@ class Site(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=True)
|
created_by = mapped_column(StringUUID, nullable=True)
|
||||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
code = mapped_column(String(255))
|
code = mapped_column(String(255))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -1,64 +1,67 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import String
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from libs.uuid_utils import uuidv7
|
from libs.uuid_utils import uuidv7
|
||||||
|
|
||||||
from .base import Base
|
from .base import Base
|
||||||
from .engine import db
|
|
||||||
from .types import LongText, StringUUID
|
from .types import LongText, StringUUID
|
||||||
|
|
||||||
|
|
||||||
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||||
__tablename__ = "datasource_oauth_params"
|
__tablename__ = "datasource_oauth_params"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||||
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
system_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
system_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class DatasourceProvider(Base):
|
class DatasourceProvider(Base):
|
||||||
__tablename__ = "datasource_providers"
|
__tablename__ = "datasource_providers"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||||
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||||
)
|
)
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
provider: Mapped[str] = mapped_column(String(128), nullable=False)
|
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
|
||||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
auth_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
encrypted_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
encrypted_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||||
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
|
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
|
||||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DatasourceOauthTenantParamConfig(Base):
|
class DatasourceOauthTenantParamConfig(Base):
|
||||||
__tablename__ = "datasource_oauth_tenant_params"
|
__tablename__ = "datasource_oauth_tenant_params"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
client_params: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default={})
|
client_params: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default={})
|
||||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|||||||
@ -75,7 +75,9 @@ class Provider(Base):
|
|||||||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
|
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
@ -138,7 +140,9 @@ class ProviderModel(Base):
|
|||||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def credential(self):
|
def credential(self):
|
||||||
@ -173,7 +177,9 @@ class TenantDefaultModel(Base):
|
|||||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TenantPreferredModelProvider(Base):
|
class TenantPreferredModelProvider(Base):
|
||||||
@ -188,7 +194,9 @@ class TenantPreferredModelProvider(Base):
|
|||||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderOrder(Base):
|
class ProviderOrder(Base):
|
||||||
@ -215,7 +223,9 @@ class ProviderOrder(Base):
|
|||||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelSetting(Base):
|
class ProviderModelSetting(Base):
|
||||||
@ -237,7 +247,9 @@ class ProviderModelSetting(Base):
|
|||||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||||
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoadBalancingModelConfig(Base):
|
class LoadBalancingModelConfig(Base):
|
||||||
@ -262,7 +274,9 @@ class LoadBalancingModelConfig(Base):
|
|||||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
||||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderCredential(Base):
|
class ProviderCredential(Base):
|
||||||
@ -282,7 +296,9 @@ class ProviderCredential(Base):
|
|||||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelCredential(Base):
|
class ProviderModelCredential(Base):
|
||||||
@ -310,4 +326,6 @@ class ProviderModelCredential(Base):
|
|||||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|||||||
@ -140,8 +140,9 @@ class Workflow(Base):
|
|||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=naive_utc_now(),
|
default=func.current_timestamp(),
|
||||||
server_onupdate=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
|
onupdate=func.current_timestamp(),
|
||||||
)
|
)
|
||||||
_environment_variables: Mapped[str] = mapped_column(
|
_environment_variables: Mapped[str] = mapped_column(
|
||||||
"environment_variables", LongText, nullable=False, default="{}"
|
"environment_variables", LongText, nullable=False, default="{}"
|
||||||
|
|||||||
@ -0,0 +1,311 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for Redis broadcast channel implementation using TestContainers.
|
||||||
|
|
||||||
|
This test suite covers real Redis interactions including:
|
||||||
|
- Multiple producer/consumer scenarios
|
||||||
|
- Network failure scenarios
|
||||||
|
- Performance under load
|
||||||
|
- Real-world usage patterns
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import redis
|
||||||
|
from testcontainers.redis import RedisContainer
|
||||||
|
|
||||||
|
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||||
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||||
|
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisBroadcastChannelIntegration:
|
||||||
|
"""Integration tests for Redis broadcast channel with real Redis instance."""
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def redis_container(self) -> Iterator[RedisContainer]:
|
||||||
|
"""Create a Redis container for integration testing."""
|
||||||
|
with RedisContainer(image="redis:6-alpine") as container:
|
||||||
|
yield container
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
|
||||||
|
"""Create a Redis client connected to the test container."""
|
||||||
|
host = redis_container.get_container_host_ip()
|
||||||
|
port = redis_container.get_exposed_port(6379)
|
||||||
|
return redis.Redis(host=host, port=port, decode_responses=False)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
|
||||||
|
"""Create a BroadcastChannel instance with real Redis client."""
|
||||||
|
return RedisBroadcastChannel(redis_client)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_test_topic_name(cls):
|
||||||
|
return f"test_topic_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# ==================== Basic Functionality Tests ===================='
|
||||||
|
|
||||||
|
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel):
|
||||||
|
topic_name = self._get_test_topic_name()
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
subscription = topic.subscribe()
|
||||||
|
consuming_event = threading.Event()
|
||||||
|
|
||||||
|
def consume():
|
||||||
|
msgs = []
|
||||||
|
consuming_event.set()
|
||||||
|
for msg in subscription:
|
||||||
|
msgs.append(msg)
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
producer_future = executor.submit(consume)
|
||||||
|
consuming_event.wait()
|
||||||
|
subscription.close()
|
||||||
|
msgs = producer_future.result(timeout=1)
|
||||||
|
assert msgs == []
|
||||||
|
|
||||||
|
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel):
|
||||||
|
"""Test complete end-to-end messaging flow."""
|
||||||
|
topic_name = "test-topic"
|
||||||
|
message = b"hello world"
|
||||||
|
|
||||||
|
# Create producer and subscriber
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
producer = topic.as_producer()
|
||||||
|
subscription = topic.subscribe()
|
||||||
|
|
||||||
|
# Publish and receive message
|
||||||
|
|
||||||
|
def producer_thread():
|
||||||
|
time.sleep(0.1) # Small delay to ensure subscriber is ready
|
||||||
|
producer.publish(message)
|
||||||
|
time.sleep(0.1)
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
def consumer_thread() -> list[bytes]:
|
||||||
|
received_messages = []
|
||||||
|
for msg in subscription:
|
||||||
|
received_messages.append(msg)
|
||||||
|
return received_messages
|
||||||
|
|
||||||
|
# Run producer and consumer
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
producer_future = executor.submit(producer_thread)
|
||||||
|
consumer_future = executor.submit(consumer_thread)
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
producer_future.result(timeout=5.0)
|
||||||
|
received_messages = consumer_future.result(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(received_messages) == 1
|
||||||
|
assert received_messages[0] == message
|
||||||
|
|
||||||
|
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
|
||||||
|
"""Test message broadcasting to multiple subscribers."""
|
||||||
|
topic_name = "broadcast-topic"
|
||||||
|
message = b"broadcast message"
|
||||||
|
subscriber_count = 5
|
||||||
|
|
||||||
|
# Create producer and multiple subscribers
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
producer = topic.as_producer()
|
||||||
|
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
|
||||||
|
|
||||||
|
def producer_thread():
|
||||||
|
time.sleep(0.2) # Allow all subscribers to connect
|
||||||
|
producer.publish(message)
|
||||||
|
time.sleep(0.2)
|
||||||
|
for sub in subscriptions:
|
||||||
|
sub.close()
|
||||||
|
|
||||||
|
def consumer_thread(subscription: Subscription) -> list[bytes]:
|
||||||
|
received_msgs = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = subscription.receive(0.1)
|
||||||
|
except SubscriptionClosedError:
|
||||||
|
break
|
||||||
|
if msg is None:
|
||||||
|
continue
|
||||||
|
received_msgs.append(msg)
|
||||||
|
if len(received_msgs) >= 1:
|
||||||
|
break
|
||||||
|
return received_msgs
|
||||||
|
|
||||||
|
# Run producer and consumers
|
||||||
|
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
|
||||||
|
producer_future = executor.submit(producer_thread)
|
||||||
|
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
producer_future.result(timeout=10.0)
|
||||||
|
msgs_by_consumers = []
|
||||||
|
for future in as_completed(consumer_futures, timeout=10.0):
|
||||||
|
msgs_by_consumers.append(future.result())
|
||||||
|
|
||||||
|
# Close all subscriptions
|
||||||
|
for subscription in subscriptions:
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
# Verify all subscribers received the message
|
||||||
|
for msgs in msgs_by_consumers:
|
||||||
|
assert len(msgs) == 1
|
||||||
|
assert msgs[0] == message
|
||||||
|
|
||||||
|
def test_topic_isolation(self, broadcast_channel: BroadcastChannel):
|
||||||
|
"""Test that different topics are isolated from each other."""
|
||||||
|
topic1_name = "topic1"
|
||||||
|
topic2_name = "topic2"
|
||||||
|
message1 = b"message for topic1"
|
||||||
|
message2 = b"message for topic2"
|
||||||
|
|
||||||
|
# Create producers and subscribers for different topics
|
||||||
|
topic1 = broadcast_channel.topic(topic1_name)
|
||||||
|
topic2 = broadcast_channel.topic(topic2_name)
|
||||||
|
|
||||||
|
def producer_thread():
|
||||||
|
time.sleep(0.1)
|
||||||
|
topic1.publish(message1)
|
||||||
|
topic2.publish(message2)
|
||||||
|
|
||||||
|
def consumer_by_thread(topic: Topic) -> list[bytes]:
|
||||||
|
subscription = topic.subscribe()
|
||||||
|
received = []
|
||||||
|
with subscription:
|
||||||
|
for msg in subscription:
|
||||||
|
received.append(msg)
|
||||||
|
if len(received) >= 1:
|
||||||
|
break
|
||||||
|
return received
|
||||||
|
|
||||||
|
# Run all threads
|
||||||
|
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||||
|
producer_future = executor.submit(producer_thread)
|
||||||
|
consumer1_future = executor.submit(consumer_by_thread, topic1)
|
||||||
|
consumer2_future = executor.submit(consumer_by_thread, topic2)
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
producer_future.result(timeout=5.0)
|
||||||
|
received_by_topic1 = consumer1_future.result(timeout=5.0)
|
||||||
|
received_by_topic2 = consumer2_future.result(timeout=5.0)
|
||||||
|
|
||||||
|
# Verify topic isolation
|
||||||
|
assert len(received_by_topic1) == 1
|
||||||
|
assert len(received_by_topic2) == 1
|
||||||
|
assert received_by_topic1[0] == message1
|
||||||
|
assert received_by_topic2[0] == message2
|
||||||
|
|
||||||
|
# ==================== Performance Tests ====================
|
||||||
|
|
||||||
|
def test_concurrent_producers(self, broadcast_channel: BroadcastChannel):
|
||||||
|
"""Test multiple producers publishing to the same topic."""
|
||||||
|
topic_name = "concurrent-producers-topic"
|
||||||
|
producer_count = 5
|
||||||
|
messages_per_producer = 5
|
||||||
|
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
subscription = topic.subscribe()
|
||||||
|
|
||||||
|
expected_total = producer_count * messages_per_producer
|
||||||
|
consumer_ready = threading.Event()
|
||||||
|
|
||||||
|
def producer_thread(producer_idx: int) -> set[bytes]:
|
||||||
|
producer = topic.as_producer()
|
||||||
|
produced = set()
|
||||||
|
for i in range(messages_per_producer):
|
||||||
|
message = f"producer_{producer_idx}_msg_{i}".encode()
|
||||||
|
produced.add(message)
|
||||||
|
producer.publish(message)
|
||||||
|
time.sleep(0.001) # Small delay to avoid overwhelming
|
||||||
|
return produced
|
||||||
|
|
||||||
|
def consumer_thread() -> set[bytes]:
|
||||||
|
received_msgs: set[bytes] = set()
|
||||||
|
with subscription:
|
||||||
|
consumer_ready.set()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = subscription.receive(timeout=0.1)
|
||||||
|
except SubscriptionClosedError:
|
||||||
|
break
|
||||||
|
if msg is None:
|
||||||
|
if len(received_msgs) >= expected_total:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
received_msgs.add(msg)
|
||||||
|
return received_msgs
|
||||||
|
|
||||||
|
# Run producers and consumer
|
||||||
|
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
|
||||||
|
consumer_future = executor.submit(consumer_thread)
|
||||||
|
consumer_ready.wait()
|
||||||
|
producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)]
|
||||||
|
|
||||||
|
sent_msgs: set[bytes] = set()
|
||||||
|
# Wait for completion
|
||||||
|
for future in as_completed(producer_futures, timeout=30.0):
|
||||||
|
sent_msgs.update(future.result())
|
||||||
|
|
||||||
|
subscription.close()
|
||||||
|
consumer_received_msgs = consumer_future.result(timeout=30.0)
|
||||||
|
|
||||||
|
# Verify message content
|
||||||
|
assert sent_msgs == consumer_received_msgs
|
||||||
|
|
||||||
|
# ==================== Resource Management Tests ====================
|
||||||
|
|
||||||
|
def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis):
|
||||||
|
"""Test proper cleanup of subscription resources."""
|
||||||
|
topic_name = "cleanup-test-topic"
|
||||||
|
|
||||||
|
# Create multiple subscriptions
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
|
||||||
|
def _consume(sub: Subscription):
|
||||||
|
for i in sub:
|
||||||
|
pass
|
||||||
|
|
||||||
|
subscriptions = []
|
||||||
|
for i in range(5):
|
||||||
|
subscription = topic.subscribe()
|
||||||
|
subscriptions.append(subscription)
|
||||||
|
|
||||||
|
# Start all subscriptions
|
||||||
|
thread = threading.Thread(target=_consume, args=(subscription,))
|
||||||
|
thread.start()
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
# Verify subscriptions are active
|
||||||
|
pubsub_info = redis_client.pubsub_numsub(topic_name)
|
||||||
|
# pubsub_numsub returns list of tuples, find our topic
|
||||||
|
topic_subscribers = 0
|
||||||
|
for channel, count in pubsub_info:
|
||||||
|
# the channel name returned by redis is bytes.
|
||||||
|
if channel == topic_name.encode():
|
||||||
|
topic_subscribers = count
|
||||||
|
break
|
||||||
|
assert topic_subscribers >= 5
|
||||||
|
|
||||||
|
# Close all subscriptions
|
||||||
|
for subscription in subscriptions:
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
# Wait a bit for cleanup
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Verify subscriptions are cleaned up
|
||||||
|
pubsub_info_after = redis_client.pubsub_numsub(topic_name)
|
||||||
|
topic_subscribers_after = 0
|
||||||
|
for channel, count in pubsub_info_after:
|
||||||
|
if channel == topic_name.encode():
|
||||||
|
topic_subscribers_after = count
|
||||||
|
break
|
||||||
|
assert topic_subscribers_after == 0
|
||||||
@ -0,0 +1,514 @@
|
|||||||
|
"""
|
||||||
|
Comprehensive unit tests for Redis broadcast channel implementation.
|
||||||
|
|
||||||
|
This test suite covers all aspects of the Redis broadcast channel including:
|
||||||
|
- Basic functionality and contract compliance
|
||||||
|
- Error handling and edge cases
|
||||||
|
- Thread safety and concurrency
|
||||||
|
- Resource management and cleanup
|
||||||
|
- Performance and reliability scenarios
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
|
||||||
|
from libs.broadcast_channel.redis.channel import (
|
||||||
|
BroadcastChannel as RedisBroadcastChannel,
|
||||||
|
)
|
||||||
|
from libs.broadcast_channel.redis.channel import (
|
||||||
|
Topic,
|
||||||
|
_RedisSubscription,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBroadcastChannel:
|
||||||
|
"""Test cases for the main BroadcastChannel class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis_client(self) -> MagicMock:
|
||||||
|
"""Create a mock Redis client for testing."""
|
||||||
|
client = MagicMock()
|
||||||
|
client.pubsub.return_value = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
|
||||||
|
"""Create a BroadcastChannel instance with mock Redis client."""
|
||||||
|
return RedisBroadcastChannel(mock_redis_client)
|
||||||
|
|
||||||
|
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
|
||||||
|
"""Test that topic() method returns a Topic instance with correct parameters."""
|
||||||
|
topic_name = "test-topic"
|
||||||
|
topic = broadcast_channel.topic(topic_name)
|
||||||
|
|
||||||
|
assert isinstance(topic, Topic)
|
||||||
|
assert topic._client == mock_redis_client
|
||||||
|
assert topic._topic == topic_name
|
||||||
|
|
||||||
|
def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel):
|
||||||
|
"""Test that different topic names create isolated Topic instances."""
|
||||||
|
topic1 = broadcast_channel.topic("topic1")
|
||||||
|
topic2 = broadcast_channel.topic("topic2")
|
||||||
|
|
||||||
|
assert topic1 is not topic2
|
||||||
|
assert topic1._topic == "topic1"
|
||||||
|
assert topic2._topic == "topic2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopic:
|
||||||
|
"""Test cases for the Topic class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis_client(self) -> MagicMock:
|
||||||
|
"""Create a mock Redis client for testing."""
|
||||||
|
client = MagicMock()
|
||||||
|
client.pubsub.return_value = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def topic(self, mock_redis_client: MagicMock) -> Topic:
|
||||||
|
"""Create a Topic instance for testing."""
|
||||||
|
return Topic(mock_redis_client, "test-topic")
|
||||||
|
|
||||||
|
def test_as_producer_returns_self(self, topic: Topic):
|
||||||
|
"""Test that as_producer() returns self as Producer interface."""
|
||||||
|
producer = topic.as_producer()
|
||||||
|
assert producer is topic
|
||||||
|
# Producer is a Protocol, check duck typing instead
|
||||||
|
assert hasattr(producer, "publish")
|
||||||
|
|
||||||
|
def test_as_subscriber_returns_self(self, topic: Topic):
|
||||||
|
"""Test that as_subscriber() returns self as Subscriber interface."""
|
||||||
|
subscriber = topic.as_subscriber()
|
||||||
|
assert subscriber is topic
|
||||||
|
# Subscriber is a Protocol, check duck typing instead
|
||||||
|
assert hasattr(subscriber, "subscribe")
|
||||||
|
|
||||||
|
def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock):
|
||||||
|
"""Test that publish() calls Redis PUBLISH with correct parameters."""
|
||||||
|
payload = b"test message"
|
||||||
|
topic.publish(payload)
|
||||||
|
|
||||||
|
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class SubscriptionTestCase:
|
||||||
|
"""Test case data for subscription tests."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
buffer_size: int
|
||||||
|
payload: bytes
|
||||||
|
expected_messages: list[bytes]
|
||||||
|
should_drop: bool = False
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisSubscription:
|
||||||
|
"""Test cases for the _RedisSubscription class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pubsub(self) -> MagicMock:
|
||||||
|
"""Create a mock PubSub instance for testing."""
|
||||||
|
pubsub = MagicMock()
|
||||||
|
pubsub.subscribe = MagicMock()
|
||||||
|
pubsub.unsubscribe = MagicMock()
|
||||||
|
pubsub.close = MagicMock()
|
||||||
|
pubsub.get_message = MagicMock()
|
||||||
|
return pubsub
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
||||||
|
"""Create a _RedisSubscription instance for testing."""
|
||||||
|
subscription = _RedisSubscription(
|
||||||
|
pubsub=mock_pubsub,
|
||||||
|
topic="test-topic",
|
||||||
|
)
|
||||||
|
yield subscription
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription:
|
||||||
|
"""Create a subscription that has been started."""
|
||||||
|
subscription._start_if_needed()
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
# ==================== Lifecycle Tests ====================
|
||||||
|
|
||||||
|
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||||
|
"""Test that subscription is properly initialized."""
|
||||||
|
subscription = _RedisSubscription(
|
||||||
|
pubsub=mock_pubsub,
|
||||||
|
topic="test-topic",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert subscription._pubsub is mock_pubsub
|
||||||
|
assert subscription._topic == "test-topic"
|
||||||
|
assert not subscription._closed.is_set()
|
||||||
|
assert subscription._dropped_count == 0
|
||||||
|
assert subscription._listener_thread is None
|
||||||
|
assert not subscription._started
|
||||||
|
|
||||||
|
def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that _start_if_needed() properly starts subscription on first call."""
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||||
|
assert subscription._started is True
|
||||||
|
assert subscription._listener_thread is not None
|
||||||
|
|
||||||
|
def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription):
|
||||||
|
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
|
||||||
|
original_thread = started_subscription._listener_thread
|
||||||
|
started_subscription._start_if_needed()
|
||||||
|
|
||||||
|
# Should not create new thread or generator
|
||||||
|
assert started_subscription._listener_thread is original_thread
|
||||||
|
|
||||||
|
def test_start_if_needed_when_closed(self, subscription: _RedisSubscription):
|
||||||
|
"""Test that _start_if_needed() raises error when subscription is closed."""
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
|
||||||
|
"""Test that _start_if_needed() raises error when pubsub is None."""
|
||||||
|
subscription._pubsub = None
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that subscription works as context manager."""
|
||||||
|
with subscription as sub:
|
||||||
|
assert sub is subscription
|
||||||
|
assert subscription._started is True
|
||||||
|
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||||
|
|
||||||
|
def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that close() is idempotent and can be called multiple times."""
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
# Close multiple times
|
||||||
|
subscription.close()
|
||||||
|
subscription.close()
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
# Should only cleanup once
|
||||||
|
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||||
|
mock_pubsub.close.assert_called_once()
|
||||||
|
assert subscription._pubsub is None
|
||||||
|
assert subscription._closed.is_set()
|
||||||
|
|
||||||
|
def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that close() properly cleans up all resources."""
|
||||||
|
subscription._start_if_needed()
|
||||||
|
thread = subscription._listener_thread
|
||||||
|
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
# Verify cleanup
|
||||||
|
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||||
|
mock_pubsub.close.assert_called_once()
|
||||||
|
assert subscription._pubsub is None
|
||||||
|
assert subscription._listener_thread is None
|
||||||
|
|
||||||
|
# Wait for thread to finish (with timeout)
|
||||||
|
if thread and thread.is_alive():
|
||||||
|
thread.join(timeout=1.0)
|
||||||
|
assert not thread.is_alive()
|
||||||
|
|
||||||
|
# ==================== Message Processing Tests ====================
|
||||||
|
|
||||||
|
def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription):
|
||||||
|
"""Test message iterator behavior with messages in queue."""
|
||||||
|
test_messages = [b"msg1", b"msg2", b"msg3"]
|
||||||
|
|
||||||
|
# Add messages to queue
|
||||||
|
for msg in test_messages:
|
||||||
|
started_subscription._queue.put_nowait(msg)
|
||||||
|
|
||||||
|
# Iterate through messages
|
||||||
|
iterator = iter(started_subscription)
|
||||||
|
received_messages = []
|
||||||
|
|
||||||
|
for msg in iterator:
|
||||||
|
received_messages.append(msg)
|
||||||
|
if len(received_messages) >= len(test_messages):
|
||||||
|
break
|
||||||
|
|
||||||
|
assert received_messages == test_messages
|
||||||
|
|
||||||
|
def test_message_iterator_when_closed(self, subscription: _RedisSubscription):
|
||||||
|
"""Test that iterator raises error when subscription is closed."""
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
|
||||||
|
iter(subscription)
|
||||||
|
|
||||||
|
# ==================== Message Enqueue Tests ====================
|
||||||
|
|
||||||
|
def test_enqueue_message_success(self, started_subscription: _RedisSubscription):
|
||||||
|
"""Test successful message enqueue."""
|
||||||
|
payload = b"test message"
|
||||||
|
|
||||||
|
started_subscription._enqueue_message(payload)
|
||||||
|
|
||||||
|
assert started_subscription._queue.qsize() == 1
|
||||||
|
assert started_subscription._queue.get_nowait() == payload
|
||||||
|
|
||||||
|
def test_enqueue_message_when_closed(self, subscription: _RedisSubscription):
|
||||||
|
"""Test message enqueue when subscription is closed."""
|
||||||
|
subscription.close()
|
||||||
|
payload = b"test message"
|
||||||
|
|
||||||
|
# Should not raise exception, but should not enqueue
|
||||||
|
subscription._enqueue_message(payload)
|
||||||
|
|
||||||
|
assert subscription._queue.empty()
|
||||||
|
|
||||||
|
def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription):
|
||||||
|
"""Test message enqueue with full queue (dropping behavior)."""
|
||||||
|
# Fill the queue
|
||||||
|
for i in range(started_subscription._queue.maxsize):
|
||||||
|
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
|
||||||
|
|
||||||
|
# Try to enqueue new message (should drop oldest)
|
||||||
|
new_message = b"new_message"
|
||||||
|
started_subscription._enqueue_message(new_message)
|
||||||
|
|
||||||
|
# Should have dropped one message and added new one
|
||||||
|
assert started_subscription._dropped_count == 1
|
||||||
|
|
||||||
|
# New message should be in queue
|
||||||
|
messages = []
|
||||||
|
while not started_subscription._queue.empty():
|
||||||
|
messages.append(started_subscription._queue.get_nowait())
|
||||||
|
|
||||||
|
assert new_message in messages
|
||||||
|
|
||||||
|
# ==================== Listener Thread Tests ====================
|
||||||
|
|
||||||
|
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
|
||||||
|
def test_listener_thread_normal_operation(
|
||||||
|
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
|
||||||
|
):
|
||||||
|
"""Test listener thread normal operation."""
|
||||||
|
# Mock message from Redis
|
||||||
|
mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"}
|
||||||
|
mock_pubsub.get_message.return_value = mock_message
|
||||||
|
|
||||||
|
# Start listener
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
# Wait a bit for processing
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Verify message was processed
|
||||||
|
assert not subscription._queue.empty()
|
||||||
|
assert subscription._queue.get_nowait() == b"test payload"
|
||||||
|
|
||||||
|
def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that listener thread ignores subscribe/unsubscribe messages."""
|
||||||
|
mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1}
|
||||||
|
mock_pubsub.get_message.return_value = mock_message
|
||||||
|
|
||||||
|
subscription._start_if_needed()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Should not enqueue subscribe messages
|
||||||
|
assert subscription._queue.empty()
|
||||||
|
|
||||||
|
def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that listener thread ignores messages from wrong channels."""
|
||||||
|
mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"}
|
||||||
|
mock_pubsub.get_message.return_value = mock_message
|
||||||
|
|
||||||
|
subscription._start_if_needed()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Should not enqueue messages from wrong channels
|
||||||
|
assert subscription._queue.empty()
|
||||||
|
|
||||||
|
def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that listener thread handles Redis exceptions gracefully."""
|
||||||
|
mock_pubsub.get_message.side_effect = Exception("Redis error")
|
||||||
|
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
# Wait for thread to handle exception
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
# Thread should still be alive but not processing
|
||||||
|
assert subscription._listener_thread is not None
|
||||||
|
assert not subscription._listener_thread.is_alive()
|
||||||
|
|
||||||
|
def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||||
|
"""Test that listener thread stops when subscription is closed."""
|
||||||
|
subscription._start_if_needed()
|
||||||
|
thread = subscription._listener_thread
|
||||||
|
|
||||||
|
# Close subscription
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
# Wait for thread to finish
|
||||||
|
if thread is not None and thread.is_alive():
|
||||||
|
thread.join(timeout=1.0)
|
||||||
|
|
||||||
|
assert thread is None or not thread.is_alive()
|
||||||
|
|
||||||
|
# ==================== Table-driven Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_case",
|
||||||
|
[
|
||||||
|
SubscriptionTestCase(
|
||||||
|
name="basic_message",
|
||||||
|
buffer_size=5,
|
||||||
|
payload=b"hello world",
|
||||||
|
expected_messages=[b"hello world"],
|
||||||
|
description="Basic message publishing and receiving",
|
||||||
|
),
|
||||||
|
SubscriptionTestCase(
|
||||||
|
name="empty_message",
|
||||||
|
buffer_size=5,
|
||||||
|
payload=b"",
|
||||||
|
expected_messages=[b""],
|
||||||
|
description="Empty message handling",
|
||||||
|
),
|
||||||
|
SubscriptionTestCase(
|
||||||
|
name="large_message",
|
||||||
|
buffer_size=5,
|
||||||
|
payload=b"x" * 10000,
|
||||||
|
expected_messages=[b"x" * 10000],
|
||||||
|
description="Large message handling",
|
||||||
|
),
|
||||||
|
SubscriptionTestCase(
|
||||||
|
name="unicode_message",
|
||||||
|
buffer_size=5,
|
||||||
|
payload="你好世界".encode(),
|
||||||
|
expected_messages=["你好世界".encode()],
|
||||||
|
description="Unicode message handling",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||||
|
"""Test various subscription scenarios using table-driven approach."""
|
||||||
|
subscription = _RedisSubscription(
|
||||||
|
pubsub=mock_pubsub,
|
||||||
|
topic="test-topic",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simulate receiving message
|
||||||
|
mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload}
|
||||||
|
mock_pubsub.get_message.return_value = mock_message
|
||||||
|
|
||||||
|
try:
|
||||||
|
with subscription:
|
||||||
|
# Wait for message processing
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Collect received messages
|
||||||
|
received = []
|
||||||
|
for msg in subscription:
|
||||||
|
received.append(msg)
|
||||||
|
if len(received) >= len(test_case.expected_messages):
|
||||||
|
break
|
||||||
|
|
||||||
|
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
|
||||||
|
finally:
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription):
|
||||||
|
"""Test concurrent close and enqueue operations."""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def close_subscription():
|
||||||
|
try:
|
||||||
|
time.sleep(0.05) # Small delay
|
||||||
|
started_subscription.close()
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
def enqueue_messages():
|
||||||
|
try:
|
||||||
|
for i in range(50):
|
||||||
|
started_subscription._enqueue_message(f"msg_{i}".encode())
|
||||||
|
time.sleep(0.001)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
# Start threads
|
||||||
|
close_thread = threading.Thread(target=close_subscription)
|
||||||
|
enqueue_thread = threading.Thread(target=enqueue_messages)
|
||||||
|
|
||||||
|
close_thread.start()
|
||||||
|
enqueue_thread.start()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
close_thread.join(timeout=2.0)
|
||||||
|
enqueue_thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Should not have any errors (operations should be safe)
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
# ==================== Error Handling Tests ====================
|
||||||
|
|
||||||
|
def test_iterator_after_close(self, subscription: _RedisSubscription):
|
||||||
|
"""Test iterator behavior after close."""
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||||
|
iter(subscription)
|
||||||
|
|
||||||
|
def test_start_after_close(self, subscription: _RedisSubscription):
|
||||||
|
"""Test start attempts after close."""
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
|
||||||
|
"""Test operations when pubsub is None."""
|
||||||
|
subscription._pubsub = None
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||||
|
subscription._start_if_needed()
|
||||||
|
|
||||||
|
# Close should still work
|
||||||
|
subscription.close() # Should not raise
|
||||||
|
|
||||||
|
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||||
|
"""Test various channel name formats."""
|
||||||
|
channel_names = [
|
||||||
|
"simple",
|
||||||
|
"with-dashes",
|
||||||
|
"with_underscores",
|
||||||
|
"with.numbers",
|
||||||
|
"WITH.UPPERCASE",
|
||||||
|
"mixed-CASE_name",
|
||||||
|
"very.long.channel.name.with.multiple.parts",
|
||||||
|
]
|
||||||
|
|
||||||
|
for channel_name in channel_names:
|
||||||
|
subscription = _RedisSubscription(
|
||||||
|
pubsub=mock_pubsub,
|
||||||
|
topic=channel_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
subscription._start_if_needed()
|
||||||
|
mock_pubsub.subscribe.assert_called_with(channel_name)
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
def test_received_on_closed_subscription(self, subscription: _RedisSubscription):
|
||||||
|
subscription.close()
|
||||||
|
|
||||||
|
with pytest.raises(SubscriptionClosedError):
|
||||||
|
subscription.receive()
|
||||||
4585
api/uv.lock
generated
4585
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -74,7 +74,8 @@ Chat applications support session persistence, allowing previous chat history to
|
|||||||
If set to `false`, can achieve async title generation by calling the conversation rename API and setting `auto_generate` to `true`.
|
If set to `false`, can achieve async title generation by calling the conversation rename API and setting `auto_generate` to `true`.
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||||
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.
|
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.<br/>
|
||||||
|
How to obtain: In the version history interface, click the copy icon on the right side of each version entry to copy the complete workflow ID.
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='trace_id' type='string' key='trace_id'>
|
<Property name='trace_id' type='string' key='trace_id'>
|
||||||
(Optional) Trace ID. Used for integration with existing business trace components to achieve end-to-end distributed tracing. If not provided, the system will automatically generate a trace_id. Supports the following three ways to pass, in order of priority:<br/>
|
(Optional) Trace ID. Used for integration with existing business trace components to achieve end-to-end distributed tracing. If not provided, the system will automatically generate a trace_id. Supports the following three ways to pass, in order of priority:<br/>
|
||||||
|
|||||||
@ -74,7 +74,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
`false`に設定すると、会話のリネームAPIを呼び出し、`auto_generate`を`true`に設定することで非同期タイトル生成を実現できます。
|
`false`に設定すると、会話のリネームAPIを呼び出し、`auto_generate`を`true`に設定することで非同期タイトル生成を実現できます。
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||||
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。
|
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。<br/>
|
||||||
|
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='trace_id' type='string' key='trace_id'>
|
<Property name='trace_id' type='string' key='trace_id'>
|
||||||
(オプション)トレースID。既存の業務システムのトレースコンポーネントと連携し、エンドツーエンドの分散トレーシングを実現するために使用します。指定がない場合、システムが自動的に trace_id を生成します。以下の3つの方法で渡すことができ、優先順位は次のとおりです:<br/>
|
(オプション)トレースID。既存の業務システムのトレースコンポーネントと連携し、エンドツーエンドの分散トレーシングを実現するために使用します。指定がない場合、システムが自動的に trace_id を生成します。以下の3つの方法で渡すことができ、優先順位は次のとおりです:<br/>
|
||||||
|
|||||||
@ -72,7 +72,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||||||
(选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
|
(选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||||
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。
|
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。<br/>
|
||||||
|
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||||
</Property>
|
</Property>
|
||||||
<Property name='trace_id' type='string' key='trace_id'>
|
<Property name='trace_id' type='string' key='trace_id'>
|
||||||
(选填)链路追踪ID。适用于与业务系统已有的trace组件打通,实现端到端分布式追踪等场景。如果未指定,系统会自动生成<code>trace_id</code>。支持以下三种方式传递,具体优先级依次为:<br/>
|
(选填)链路追踪ID。适用于与业务系统已有的trace组件打通,实现端到端分布式追踪等场景。如果未指定,系统会自动生成<code>trace_id</code>。支持以下三种方式传递,具体优先级依次为:<br/>
|
||||||
|
|||||||
@ -344,7 +344,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
### パス
|
### パス
|
||||||
- `workflow_id` (string) 必須 特定バージョンのワークフローを指定するためのワークフローID
|
- `workflow_id` (string) 必須 特定バージョンのワークフローを指定するためのワークフローID
|
||||||
|
|
||||||
取得方法:バージョン履歴で特定バージョンのワークフローIDを照会できます。
|
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||||
|
|
||||||
### リクエストボディ
|
### リクエストボディ
|
||||||
- `inputs` (object) 必須
|
- `inputs` (object) 必須
|
||||||
|
|||||||
@ -334,7 +334,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
|
|||||||
### Path
|
### Path
|
||||||
- `workflow_id` (string) Required 工作流ID,用于指定特定版本的工作流
|
- `workflow_id` (string) Required 工作流ID,用于指定特定版本的工作流
|
||||||
|
|
||||||
获取方式:可以在版本历史中查询特定版本的工作流ID。
|
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||||
|
|
||||||
### Request Body
|
### Request Body
|
||||||
- `inputs` (object) Required
|
- `inputs` (object) Required
|
||||||
|
|||||||
@ -86,7 +86,7 @@ const ModelList: FC<ModelListProps> = ({
|
|||||||
{
|
{
|
||||||
models.map(model => (
|
models.map(model => (
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={`${model.model}-${model.fetch_from}`}
|
key={`${model.model}-${model.model_type}-${model.fetch_from}`}
|
||||||
{...{
|
{...{
|
||||||
model,
|
model,
|
||||||
provider,
|
provider,
|
||||||
|
|||||||
@ -856,6 +856,18 @@
|
|||||||
color: var(--color-prettylights-syntax-comment);
|
color: var(--color-prettylights-syntax-comment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.markdown-body .katex {
|
||||||
|
/* Allow long inline formulas to wrap instead of overflowing */
|
||||||
|
white-space: normal !important;
|
||||||
|
overflow-wrap: break-word; /* better cross-browser support */
|
||||||
|
word-break: break-word; /* non-standard fallback for older WebKit/Blink */
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-body .katex-display {
|
||||||
|
/* Fallback for very long display equations */
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
|
||||||
.markdown-body .pl-c1,
|
.markdown-body .pl-c1,
|
||||||
.markdown-body .pl-s .pl-v {
|
.markdown-body .pl-s .pl-v {
|
||||||
color: var(--color-prettylights-syntax-constant);
|
color: var(--color-prettylights-syntax-constant);
|
||||||
|
|||||||
@ -45,5 +45,118 @@ describe('get-icon', () => {
|
|||||||
const result = getIconFromMarketPlace(pluginId)
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
expect(result).toBe(`${MARKETPLACE_API_PREFIX}/plugins/${pluginId}/icon`)
|
expect(result).toBe(`${MARKETPLACE_API_PREFIX}/plugins/${pluginId}/icon`)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Security tests: Path traversal attempts
|
||||||
|
* These tests document current behavior and potential security concerns
|
||||||
|
* Note: Current implementation does not sanitize path traversal sequences
|
||||||
|
*/
|
||||||
|
test('handles path traversal attempts', () => {
|
||||||
|
const pluginId = '../../../etc/passwd'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
// Current implementation includes path traversal sequences in URL
|
||||||
|
// This is a potential security concern that should be addressed
|
||||||
|
expect(result).toContain('../')
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles multiple path traversal attempts', () => {
|
||||||
|
const pluginId = '../../../../etc/passwd'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
// Current implementation includes path traversal sequences in URL
|
||||||
|
expect(result).toContain('../')
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('passes through URL-encoded path traversal sequences', () => {
|
||||||
|
const pluginId = '..%2F..%2Fetc%2Fpasswd'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Security tests: Null and undefined handling
|
||||||
|
* These tests document current behavior with invalid input types
|
||||||
|
* Note: Current implementation converts null/undefined to strings instead of throwing
|
||||||
|
*/
|
||||||
|
test('handles null plugin ID', () => {
|
||||||
|
// Current implementation converts null to string "null"
|
||||||
|
const result = getIconFromMarketPlace(null as any)
|
||||||
|
expect(result).toContain('null')
|
||||||
|
// This is a potential issue - should validate input type
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles undefined plugin ID', () => {
|
||||||
|
// Current implementation converts undefined to string "undefined"
|
||||||
|
const result = getIconFromMarketPlace(undefined as any)
|
||||||
|
expect(result).toContain('undefined')
|
||||||
|
// This is a potential issue - should validate input type
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Security tests: URL-sensitive characters
|
||||||
|
* These tests verify that URL-sensitive characters are handled appropriately
|
||||||
|
*/
|
||||||
|
test('does not encode URL-sensitive characters', () => {
|
||||||
|
const pluginId = 'plugin/with?special=chars#hash'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
// Note: Current implementation doesn't encode, but test documents the behavior
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
expect(result).toContain('?')
|
||||||
|
expect(result).toContain('#')
|
||||||
|
expect(result).toContain('=')
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles URL characters like & and %', () => {
|
||||||
|
const pluginId = 'plugin&with%encoding'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Edge case tests: Extreme inputs
|
||||||
|
* These tests verify behavior with unusual but valid inputs
|
||||||
|
*/
|
||||||
|
test('handles very long plugin ID', () => {
|
||||||
|
const pluginId = 'a'.repeat(10000)
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
expect(result.length).toBeGreaterThan(10000)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles Unicode characters', () => {
|
||||||
|
const pluginId = '插件-🚀-测试-日本語'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles control characters', () => {
|
||||||
|
const pluginId = 'plugin\nwith\ttabs\r\nand\0null'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Security tests: XSS attempts
|
||||||
|
* These tests verify that XSS attempts are handled appropriately
|
||||||
|
*/
|
||||||
|
test('handles XSS attempts with script tags', () => {
|
||||||
|
const pluginId = '<script>alert("xss")</script>'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
// Note: Current implementation doesn't sanitize, but test documents the behavior
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles XSS attempts with event handlers', () => {
|
||||||
|
const pluginId = 'plugin"onerror="alert(1)"'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('handles XSS attempts with encoded script tags', () => {
|
||||||
|
const pluginId = '%3Cscript%3Ealert%28%22xss%22%29%3C%2Fscript%3E'
|
||||||
|
const result = getIconFromMarketPlace(pluginId)
|
||||||
|
expect(result).toContain(pluginId)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -87,7 +87,8 @@ describe('time', () => {
|
|||||||
test('works with timestamps', () => {
|
test('works with timestamps', () => {
|
||||||
const date = 1705276800000 // 2024-01-15 00:00:00 UTC
|
const date = 1705276800000 // 2024-01-15 00:00:00 UTC
|
||||||
const result = formatTime({ date, dateFormat: 'YYYY-MM-DD' })
|
const result = formatTime({ date, dateFormat: 'YYYY-MM-DD' })
|
||||||
expect(result).toContain('2024-01-1') // Account for timezone differences
|
// Account for timezone differences: UTC-5 to UTC+8 can result in 2024-01-14 or 2024-01-15
|
||||||
|
expect(result).toMatch(/^2024-01-(14|15)$/)
|
||||||
})
|
})
|
||||||
|
|
||||||
test('handles ISO 8601 format', () => {
|
test('handles ISO 8601 format', () => {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user