diff --git a/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py new file mode 100644 index 0000000000..0e188ec080 --- /dev/null +++ b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py @@ -0,0 +1,26 @@ +"""add qdrant_endpoint to tidb_auth_bindings + +Revision ID: 8574b23a38fd +Revises: 6b5f9f8b1a2c +Create Date: 2026-04-14 15:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8574b23a38fd" +down_revision = "6b5f9f8b1a2c" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True)) + + +def downgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.drop_column("qdrant_endpoint") diff --git a/api/models/dataset.py b/api/models/dataset.py index 4540c29206..50301dd2d7 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1305,6 +1305,7 @@ class TidbAuthBinding(TypeBase): ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) + qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py index bb8a580ebf..abca55f540 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -1,4 +1,5 @@ import json +import logging import os import uuid from collections.abc import Generator, Iterable, Sequence @@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any import httpx import qdrant_client + +logger = logging.getLogger(__name__) from flask import current_app from httpx import DigestAuth from pydantic import BaseModel @@ -421,13 +424,16 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id) stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: + logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id) with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: + logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: @@ -437,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): .limit(1) ) if idle_tidb_auth_binding: + logger.info( + "Assigning idle cluster %s to tenant %s", + idle_tidb_auth_binding.cluster_id, + dataset.tenant_id, + ) idle_tidb_auth_binding.active = True idle_tidb_auth_binding.tenant_id = dataset.tenant_id db.session.commit() + tidb_auth_binding = idle_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" else: + logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id) new_cluster = TidbService.create_tidb_serverless_cluster( dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_API_URL or "", @@ -450,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): dify_config.TIDB_PRIVATE_KEY or "", dify_config.TIDB_REGION or "", ) + logger.info( + "New cluster created: cluster_id=%s, qdrant_endpoint=%s", + new_cluster["cluster_id"], + new_cluster.get("qdrant_endpoint"), + ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), tenant_id=dataset.tenant_id, active=True, status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() + tidb_auth_binding = new_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" else: + logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + qdrant_url = ( + (tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or "" + ) + logger.info( + "Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)", + qdrant_url, + tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None, + dify_config.TIDB_ON_QDRANT_URL, + ) + if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix @@ -479,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL or "", + endpoint=qdrant_url, api_key=TIDB_ON_QDRANT_API_KEY, root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index 37114be6e7..ece061db67 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -1,3 +1,4 @@ +import logging import time import uuid from collections.abc import Sequence @@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +logger = logging.getLogger(__name__) + # Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn _tidb_http_client: httpx.Client = get_pooled_http_client( "tidb:cloud", @@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client( class TidbService: + @staticmethod + def extract_qdrant_endpoint(cluster_response: dict) -> str | None: + """Extract the qdrant endpoint URL from a Get Cluster API response. + + Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``), + prepends ``qdrant-`` and wraps it as an ``https://`` URL. + """ + endpoints = cluster_response.get("endpoints") or {} + public = endpoints.get("public") or {} + host = public.get("host") + if host: + return f"https://qdrant-{host}" + return None + + @staticmethod + def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: + """Call Get Cluster API and extract the qdrant endpoint. + + Use ``extract_qdrant_endpoint`` instead when you already have + the cluster response to avoid a redundant API call. + """ + try: + logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if not cluster_response: + logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) + return None + qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response) + if qdrant_url: + logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) + return qdrant_url + logger.warning( + "No endpoints.public.host found for cluster %s, response keys: %s", + cluster_id, + list(cluster_response.keys()), + ) + except Exception: + logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) + return None + @staticmethod def create_tidb_serverless_cluster( project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str @@ -57,6 +100,7 @@ class TidbService: "rootPassword": password, } + logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region) response = _tidb_http_client.post( f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) ) @@ -64,21 +108,39 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_id = response_data["clusterId"] + logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id) retry_count = 0 max_retries = 30 while retry_count < max_retries: cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) if cluster_response["state"] == "ACTIVE": user_prefix = cluster_response["userPrefix"] + qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response) + logger.info( + "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", + cluster_id, + user_prefix, + qdrant_endpoint, + ) return { "cluster_id": cluster_id, "cluster_name": display_name, "account": f"{user_prefix}.root", "password": password, + "qdrant_endpoint": qdrant_endpoint, } - time.sleep(30) # wait 30 seconds before retrying + logger.info( + "Cluster %s state=%s, retry %d/%d", + cluster_id, + cluster_response["state"], + retry_count + 1, + max_retries, + ) + time.sleep(30) retry_count += 1 + logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) else: + logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() @staticmethod @@ -243,19 +305,29 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_infos = [] + logger.info("Batch created %d clusters", len(response_data.get("clusters", []))) for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" cached_password = redis_client.get(cache_key) if not cached_password: + logger.warning("No cached password for cluster %s, skipping", item["displayName"]) continue + qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) + logger.info( + "Batch cluster %s: qdrant_endpoint=%s", + item["clusterId"], + qdrant_endpoint, + ) cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", "password": cached_password.decode("utf-8"), + "qdrant_endpoint": qdrant_endpoint, } cluster_infos.append(cluster_info) return cluster_infos else: + logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() return [] diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index 3e9229fea5..76802de62e 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds: assert exc_info.value.status_code == 500 - def test_delete_by_ids_with_large_batch(self, vector_instance): - """Test deletion with a large batch of IDs.""" - # Create 1000 IDs + def test_delete_by_ids_with_exactly_1000(self, vector_instance): + """Test deletion with exactly 1000 IDs triggers a single batch.""" ids = [f"doc_{i}" for i in range(1000)] vector_instance.delete_by_ids(ids) - # Verify single delete call with all IDs vector_instance._client.delete.assert_called_once() call_args = vector_instance._client.delete.call_args @@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds: filter_obj = filter_selector.filter field_condition = filter_obj.must[0] - # Verify all 1000 IDs are in the batch assert len(field_condition.match.any) == 1000 assert "doc_0" in field_condition.match.any assert "doc_999" in field_condition.match.any + def test_delete_by_ids_splits_into_batches(self, vector_instance): + """Test deletion with >1000 IDs triggers multiple batched calls.""" + ids = [f"doc_{i}" for i in range(2500)] + + vector_instance.delete_by_ids(ids) + + assert vector_instance._client.delete.call_count == 3 + + batches = [] + for call in vector_instance._client.delete.call_args_list: + filter_selector = call[1]["points_selector"] + field_condition = filter_selector.filter.must[0] + batches.append(field_condition.match.any) + + assert len(batches[0]) == 1000 + assert len(batches[1]) == 1000 + assert len(batches[2]) == 500 + def test_delete_by_ids_filter_structure(self, vector_instance): """Test that the filter structure is correctly constructed.""" ids = ["doc1", "doc2"] @@ -157,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds: # Verify MatchAny structure assert isinstance(field_condition.match, rest.MatchAny) assert field_condition.match.any == ids + + +class TestInitVectorEndpointSelection: + """Test that init_vector selects the correct qdrant endpoint. + + We avoid importing the full module (which triggers Flask app context) + by testing the endpoint selection logic directly on TidbOnQdrantConfig. + """ + + def test_uses_binding_endpoint_when_present(self): + binding_endpoint = "https://qdrant-custom.tidb.com" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-custom.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-custom.tidb.com" + + def test_falls_back_to_global_when_binding_endpoint_is_none(self): + binding_endpoint = None + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-global.tidb.com" + + def test_falls_back_to_empty_when_both_none(self): + binding_endpoint = None + global_url = None + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "" + + def test_binding_endpoint_takes_precedence_over_global(self): + binding_endpoint = "https://qdrant-ap-southeast.tidb.com" + global_url = "https://qdrant-us-east.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-ap-southeast.tidb.com" + + def test_empty_string_binding_endpoint_falls_back_to_global(self): + binding_endpoint = "" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py new file mode 100644 index 0000000000..c1ffbacbbc --- /dev/null +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -0,0 +1,218 @@ +from unittest.mock import MagicMock, patch + +import pytest +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService + + +class TestExtractQdrantEndpoint: + """Unit tests for TidbService.extract_qdrant_endpoint.""" + + def test_returns_endpoint_when_host_present(self): + response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}} + result = TidbService.extract_qdrant_endpoint(response) + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + def test_returns_none_when_host_missing(self): + response = {"endpoints": {"public": {}}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_public_missing(self): + response = {"endpoints": {}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_endpoints_missing(self): + assert TidbService.extract_qdrant_endpoint({}) is None + + +class TestFetchQdrantEndpoint: + """Unit tests for TidbService.fetch_qdrant_endpoint.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_endpoint_when_host_present(self, mock_get_cluster): + mock_get_cluster.return_value = { + "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}} + } + result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster): + mock_get_cluster.return_value = None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_host_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {"endpoints": {"public": {}}} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_endpoints_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_on_exception(self, mock_get_cluster): + mock_get_cluster.side_effect = RuntimeError("network error") + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + +class TestCreateTidbServerlessClusterQdrantEndpoint: + """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = { + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}}, + } + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] is None + + +class TestBatchCreateTidbServerlessClusterQdrantEndpoint: + """Verify that batch_create includes qdrant_endpoint per cluster.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + cluster_name = "abc123" + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = b"password123" + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 1 + assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index c4c203c150..e242b0c667 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -57,6 +57,7 @@ def create_clusters(batch_size): cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), active=False, status=TidbAuthBindingStatus.CREATING, )