mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
feat: tidb endpoint (#35158)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
98897a5379
commit
9fd196642d
@ -0,0 +1,26 @@
|
||||
"""add qdrant_endpoint to tidb_auth_bindings
|
||||
|
||||
Revision ID: 8574b23a38fd
|
||||
Revises: 6b5f9f8b1a2c
|
||||
Create Date: 2026-04-14 15:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8574b23a38fd"
|
||||
down_revision = "6b5f9f8b1a2c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.drop_column("qdrant_endpoint")
|
||||
@ -1305,6 +1305,7 @@ class TidbAuthBinding(TypeBase):
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
import qdrant_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from flask import current_app
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel
|
||||
@ -421,13 +424,16 @@ class TidbOnQdrantVector(BaseVector):
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id)
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if not tidb_auth_binding:
|
||||
logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id)
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if tidb_auth_binding:
|
||||
logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
@ -437,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
.limit(1)
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
logger.info(
|
||||
"Assigning idle cluster %s to tenant %s",
|
||||
idle_tidb_auth_binding.cluster_id,
|
||||
dataset.tenant_id,
|
||||
)
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
tidb_auth_binding = idle_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id)
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
dify_config.TIDB_API_URL or "",
|
||||
@ -450,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
dify_config.TIDB_PRIVATE_KEY or "",
|
||||
dify_config.TIDB_REGION or "",
|
||||
)
|
||||
logger.info(
|
||||
"New cluster created: cluster_id=%s, qdrant_endpoint=%s",
|
||||
new_cluster["cluster_id"],
|
||||
new_cluster.get("qdrant_endpoint"),
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
tidb_auth_binding = new_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
qdrant_url = (
|
||||
(tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or ""
|
||||
)
|
||||
logger.info(
|
||||
"Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)",
|
||||
qdrant_url,
|
||||
tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None,
|
||||
dify_config.TIDB_ON_QDRANT_URL,
|
||||
)
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -479,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
|
||||
endpoint=qdrant_url,
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=str(config.root_path),
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
|
||||
_tidb_http_client: httpx.Client = get_pooled_http_client(
|
||||
"tidb:cloud",
|
||||
@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client(
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
|
||||
"""Extract the qdrant endpoint URL from a Get Cluster API response.
|
||||
|
||||
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
|
||||
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
|
||||
"""
|
||||
endpoints = cluster_response.get("endpoints") or {}
|
||||
public = endpoints.get("public") or {}
|
||||
host = public.get("host")
|
||||
if host:
|
||||
return f"https://qdrant-{host}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
|
||||
"""Call Get Cluster API and extract the qdrant endpoint.
|
||||
|
||||
Use ``extract_qdrant_endpoint`` instead when you already have
|
||||
the cluster response to avoid a redundant API call.
|
||||
"""
|
||||
try:
|
||||
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if not cluster_response:
|
||||
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
|
||||
return None
|
||||
qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
if qdrant_url:
|
||||
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
|
||||
return qdrant_url
|
||||
logger.warning(
|
||||
"No endpoints.public.host found for cluster %s, response keys: %s",
|
||||
cluster_id,
|
||||
list(cluster_response.keys()),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def create_tidb_serverless_cluster(
|
||||
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
@ -57,6 +100,7 @@ class TidbService:
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
|
||||
response = _tidb_http_client.post(
|
||||
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
@ -64,21 +108,39 @@ class TidbService:
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_id = response_data["clusterId"]
|
||||
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
|
||||
retry_count = 0
|
||||
max_retries = 30
|
||||
while retry_count < max_retries:
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if cluster_response["state"] == "ACTIVE":
|
||||
user_prefix = cluster_response["userPrefix"]
|
||||
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
logger.info(
|
||||
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
|
||||
cluster_id,
|
||||
user_prefix,
|
||||
qdrant_endpoint,
|
||||
)
|
||||
return {
|
||||
"cluster_id": cluster_id,
|
||||
"cluster_name": display_name,
|
||||
"account": f"{user_prefix}.root",
|
||||
"password": password,
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
time.sleep(30) # wait 30 seconds before retrying
|
||||
logger.info(
|
||||
"Cluster %s state=%s, retry %d/%d",
|
||||
cluster_id,
|
||||
cluster_response["state"],
|
||||
retry_count + 1,
|
||||
max_retries,
|
||||
)
|
||||
time.sleep(30)
|
||||
retry_count += 1
|
||||
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
|
||||
else:
|
||||
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
@ -243,19 +305,29 @@ class TidbService:
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
cached_password = redis_client.get(cache_key)
|
||||
if not cached_password:
|
||||
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
|
||||
continue
|
||||
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
|
||||
logger.info(
|
||||
"Batch cluster %s: qdrant_endpoint=%s",
|
||||
item["clusterId"],
|
||||
qdrant_endpoint,
|
||||
)
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": cached_password.decode("utf-8"),
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
return []
|
||||
|
||||
@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
def test_delete_by_ids_with_large_batch(self, vector_instance):
|
||||
"""Test deletion with a large batch of IDs."""
|
||||
# Create 1000 IDs
|
||||
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
|
||||
"""Test deletion with exactly 1000 IDs triggers a single batch."""
|
||||
ids = [f"doc_{i}" for i in range(1000)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify single delete call with all IDs
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
filter_obj = filter_selector.filter
|
||||
field_condition = filter_obj.must[0]
|
||||
|
||||
# Verify all 1000 IDs are in the batch
|
||||
assert len(field_condition.match.any) == 1000
|
||||
assert "doc_0" in field_condition.match.any
|
||||
assert "doc_999" in field_condition.match.any
|
||||
|
||||
def test_delete_by_ids_splits_into_batches(self, vector_instance):
|
||||
"""Test deletion with >1000 IDs triggers multiple batched calls."""
|
||||
ids = [f"doc_{i}" for i in range(2500)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
assert vector_instance._client.delete.call_count == 3
|
||||
|
||||
batches = []
|
||||
for call in vector_instance._client.delete.call_args_list:
|
||||
filter_selector = call[1]["points_selector"]
|
||||
field_condition = filter_selector.filter.must[0]
|
||||
batches.append(field_condition.match.any)
|
||||
|
||||
assert len(batches[0]) == 1000
|
||||
assert len(batches[1]) == 1000
|
||||
assert len(batches[2]) == 500
|
||||
|
||||
def test_delete_by_ids_filter_structure(self, vector_instance):
|
||||
"""Test that the filter structure is correctly constructed."""
|
||||
ids = ["doc1", "doc2"]
|
||||
@ -157,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
# Verify MatchAny structure
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert field_condition.match.any == ids
|
||||
|
||||
|
||||
class TestInitVectorEndpointSelection:
|
||||
"""Test that init_vector selects the correct qdrant endpoint.
|
||||
|
||||
We avoid importing the full module (which triggers Flask app context)
|
||||
by testing the endpoint selection logic directly on TidbOnQdrantConfig.
|
||||
"""
|
||||
|
||||
def test_uses_binding_endpoint_when_present(self):
|
||||
binding_endpoint = "https://qdrant-custom.tidb.com"
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-custom.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-custom.tidb.com"
|
||||
|
||||
def test_falls_back_to_global_when_binding_endpoint_is_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-global.tidb.com"
|
||||
|
||||
def test_falls_back_to_empty_when_both_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = None
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == ""
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == ""
|
||||
|
||||
def test_binding_endpoint_takes_precedence_over_global(self):
|
||||
binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
|
||||
global_url = "https://qdrant-us-east.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
|
||||
|
||||
def test_empty_string_binding_endpoint_falls_back_to_global(self):
|
||||
binding_endpoint = ""
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
|
||||
@ -0,0 +1,218 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
|
||||
|
||||
|
||||
class TestExtractQdrantEndpoint:
|
||||
"""Unit tests for TidbService.extract_qdrant_endpoint."""
|
||||
|
||||
def test_returns_endpoint_when_host_present(self):
|
||||
response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
|
||||
result = TidbService.extract_qdrant_endpoint(response)
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
def test_returns_none_when_host_missing(self):
|
||||
response = {"endpoints": {"public": {}}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_public_missing(self):
|
||||
response = {"endpoints": {}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_endpoints_missing(self):
|
||||
assert TidbService.extract_qdrant_endpoint({}) is None
|
||||
|
||||
|
||||
class TestFetchQdrantEndpoint:
|
||||
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {
|
||||
"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
|
||||
}
|
||||
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = None
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_host_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {"endpoints": {"public": {}}}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_endpoints_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_on_exception(self, mock_get_cluster):
|
||||
mock_get_cluster.side_effect = RuntimeError("network error")
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {
|
||||
"state": "ACTIVE",
|
||||
"userPrefix": "pfx",
|
||||
"endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
|
||||
}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] is None
|
||||
|
||||
|
||||
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that batch_create includes qdrant_endpoint per cluster."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
cluster_name = "abc123"
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = b"password123"
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterRetry:
|
||||
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.side_effect = [
|
||||
{"state": "CREATING", "userPrefix": ""},
|
||||
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
|
||||
]
|
||||
|
||||
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
|
||||
assert mock_get_cluster.call_count == 2
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
|
||||
|
||||
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=400, text="Bad Request")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
|
||||
mock_http.post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 400"):
|
||||
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
|
||||
class TestBatchCreateEdgeCases:
|
||||
"""Cover logging/edge-case branches in batch_create."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 0
|
||||
mock_fetch_ep.assert_not_called()
|
||||
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=500, text="Server Error")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
|
||||
mock_http.post.return_value = mock_response
|
||||
mock_redis.setex = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
@ -57,6 +57,7 @@ def create_clusters(batch_size):
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
active=False,
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user