diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py deleted file mode 100644 index 40f6794af2..0000000000 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ /dev/null @@ -1,324 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Sequence - -import httpx -from httpx import DigestAuth - -from configs import dify_config -from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.dataset import TidbAuthBinding -from models.enums import TidbAuthBindingStatus - -logger = logging.getLogger(__name__) - -# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn -_tidb_http_client: httpx.Client = get_pooled_http_client( - "tidb:cloud", - lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), -) - - -class TidbService: - @staticmethod - def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: - """Fetch the qdrant endpoint for a cluster by calling the Get Cluster API. - - The v1beta1 serverless Get Cluster response contains - ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``). - We prepend ``qdrant-`` and wrap it as an ``https://`` URL. - """ - try: - logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) - cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) - if not cluster_response: - logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) - return None - # v1beta1 serverless: endpoints.public.host - endpoints = cluster_response.get("endpoints") or {} - public = endpoints.get("public") or {} - host = public.get("host") - if host: - qdrant_url = f"https://qdrant-{host}" - logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) - return qdrant_url - logger.warning( - "No endpoints.public.host found for cluster %s, response keys: %s", - cluster_id, - list(cluster_response.keys()), - ) - except Exception: - logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) - return None - - @staticmethod - def create_tidb_serverless_cluster( - project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str - ): - """ - Creates a new TiDB Serverless cluster. - :param project_id: The project ID of the TiDB Cloud project (required). - :param api_url: The URL of the TiDB Cloud API (required). - :param iam_url: The URL of the TiDB Cloud IAM API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - :param region: The region where the cluster will be created (required). - - :return: The response from the API. - """ - - region_object = { - "name": region, - } - - labels = { - "tidb.cloud/project": project_id, - } - - spending_limit = { - "monthly": dify_config.TIDB_SPEND_LIMIT, - } - password = str(uuid.uuid4()).replace("-", "")[:16] - display_name = str(uuid.uuid4()).replace("-", "")[:16] - cluster_data = { - "displayName": display_name, - "region": region_object, - "labels": labels, - "spendingLimit": spending_limit, - "rootPassword": password, - } - - logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region) - response = _tidb_http_client.post( - f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) - ) - - if response.status_code == 200: - response_data = response.json() - cluster_id = response_data["clusterId"] - logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id) - retry_count = 0 - max_retries = 30 - while retry_count < max_retries: - cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) - if cluster_response["state"] == "ACTIVE": - user_prefix = cluster_response["userPrefix"] - qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, cluster_id) - logger.info( - "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", - cluster_id, - user_prefix, - qdrant_endpoint, - ) - return { - "cluster_id": cluster_id, - "cluster_name": display_name, - "account": f"{user_prefix}.root", - "password": password, - "qdrant_endpoint": qdrant_endpoint, - } - logger.info( - "Cluster %s state=%s, retry %d/%d", - cluster_id, - cluster_response["state"], - retry_count + 1, - max_retries, - ) - time.sleep(30) - retry_count += 1 - logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) - else: - logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text) - response.raise_for_status() - - @staticmethod - def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): - """ - Deletes a specific TiDB Serverless cluster. - - :param api_url: The URL of the TiDB Cloud API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - :param cluster_id: The ID of the cluster to be deleted (required). - :return: The response from the API. - """ - - response = _tidb_http_client.delete( - f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key) - ) - - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() - - @staticmethod - def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): - """ - Deletes a specific TiDB Serverless cluster. - - :param api_url: The URL of the TiDB Cloud API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - :param cluster_id: The ID of the cluster to be deleted (required). - :return: The response from the API. - """ - - response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) - - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() - - @staticmethod - def change_tidb_serverless_root_password( - api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str - ): - """ - Changes the root password of a specific TiDB Serverless cluster. - - :param api_url: The URL of the TiDB Cloud API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - :param cluster_id: The ID of the cluster for which the password is to be changed (required).+ - :param account: The account for which the password is to be changed (required). - :param new_password: The new password for the root user (required). - :return: The response from the API. - """ - - body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} - - response = _tidb_http_client.patch( - f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", - json=body, - auth=DigestAuth(public_key, private_key), - ) - - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() - - @staticmethod - def batch_update_tidb_serverless_cluster_status( - tidb_serverless_list: Sequence[TidbAuthBinding], - project_id: str, - api_url: str, - iam_url: str, - public_key: str, - private_key: str, - ): - """ - Update the status of a new TiDB Serverless cluster. - :param tidb_serverless_list: The TiDB serverless list (required). - :param project_id: The project ID of the TiDB Cloud project (required). - :param api_url: The URL of the TiDB Cloud API (required). - :param iam_url: The URL of the TiDB Cloud IAM API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - - :return: The response from the API. - """ - tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} - cluster_ids = [item.cluster_id for item in tidb_serverless_list] - params = {"clusterIds": cluster_ids, "view": "BASIC"} - response = _tidb_http_client.get( - f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key) - ) - - if response.status_code == 200: - response_data = response.json() - for item in response_data["clusters"]: - state = item["state"] - userPrefix = item["userPrefix"] - if state == "ACTIVE" and len(userPrefix) > 0: - cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = TidbAuthBindingStatus.ACTIVE - cluster_info.account = f"{userPrefix}.root" - db.session.add(cluster_info) - db.session.commit() - else: - response.raise_for_status() - - @staticmethod - def batch_create_tidb_serverless_cluster( - batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str - ) -> list[dict]: - """ - Creates a new TiDB Serverless cluster. - :param batch_size: The batch size (required). - :param project_id: The project ID of the TiDB Cloud project (required). - :param api_url: The URL of the TiDB Cloud API (required). - :param iam_url: The URL of the TiDB Cloud IAM API (required). - :param public_key: The public key for the API (required). - :param private_key: The private key for the API (required). - :param region: The region where the cluster will be created (required). - - :return: The response from the API. - """ - clusters = [] - for _ in range(batch_size): - region_object = { - "name": region, - } - - labels = { - "tidb.cloud/project": project_id, - } - - spending_limit = { - "monthly": dify_config.TIDB_SPEND_LIMIT, - } - password = str(uuid.uuid4()).replace("-", "")[:16] - display_name = str(uuid.uuid4()).replace("-", "") - cluster_data = { - "cluster": { - "displayName": display_name, - "region": region_object, - "labels": labels, - "spendingLimit": spending_limit, - "rootPassword": password, - } - } - cache_key = f"tidb_serverless_cluster_password:{display_name}" - redis_client.setex(cache_key, 3600, password) - clusters.append(cluster_data) - - request_body = {"requests": clusters} - response = _tidb_http_client.post( - f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) - ) - - if response.status_code == 200: - response_data = response.json() - cluster_infos = [] - logger.info("Batch created %d clusters", len(response_data.get("clusters", []))) - for item in response_data["clusters"]: - cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" - cached_password = redis_client.get(cache_key) - if not cached_password: - logger.warning("No cached password for cluster %s, skipping", item["displayName"]) - continue - qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) - logger.info( - "Batch cluster %s: qdrant_endpoint=%s", - item["clusterId"], - qdrant_endpoint, - ) - cluster_info = { - "cluster_id": item["clusterId"], - "cluster_name": item["displayName"], - "account": "root", - "password": cached_password.decode("utf-8"), - "qdrant_endpoint": qdrant_endpoint, - } - cluster_infos.append(cluster_info) - return cluster_infos - else: - logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text) - response.raise_for_status() - return [] diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index 8b425164e4..ece061db67 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -24,12 +24,25 @@ _tidb_http_client: httpx.Client = get_pooled_http_client( class TidbService: @staticmethod - def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: - """Fetch the qdrant endpoint for a cluster by calling the Get Cluster API. + def extract_qdrant_endpoint(cluster_response: dict) -> str | None: + """Extract the qdrant endpoint URL from a Get Cluster API response. - The v1beta1 serverless Get Cluster response contains - ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``). - We prepend ``qdrant-`` and wrap it as an ``https://`` URL. + Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``), + prepends ``qdrant-`` and wraps it as an ``https://`` URL. + """ + endpoints = cluster_response.get("endpoints") or {} + public = endpoints.get("public") or {} + host = public.get("host") + if host: + return f"https://qdrant-{host}" + return None + + @staticmethod + def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: + """Call Get Cluster API and extract the qdrant endpoint. + + Use ``extract_qdrant_endpoint`` instead when you already have + the cluster response to avoid a redundant API call. """ try: logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) @@ -37,12 +50,8 @@ class TidbService: if not cluster_response: logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) return None - # v1beta1 serverless: endpoints.public.host - endpoints = cluster_response.get("endpoints") or {} - public = endpoints.get("public") or {} - host = public.get("host") - if host: - qdrant_url = f"https://qdrant-{host}" + qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response) + if qdrant_url: logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) return qdrant_url logger.warning( @@ -106,10 +115,12 @@ class TidbService: cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) if cluster_response["state"] == "ACTIVE": user_prefix = cluster_response["userPrefix"] - qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, cluster_id) + qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response) logger.info( "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", - cluster_id, user_prefix, qdrant_endpoint, + cluster_id, + user_prefix, + qdrant_endpoint, ) return { "cluster_id": cluster_id, @@ -118,7 +129,13 @@ class TidbService: "password": password, "qdrant_endpoint": qdrant_endpoint, } - logger.info("Cluster %s state=%s, retry %d/%d", cluster_id, cluster_response["state"], retry_count + 1, max_retries) + logger.info( + "Cluster %s state=%s, retry %d/%d", + cluster_id, + cluster_response["state"], + retry_count + 1, + max_retries, + ) time.sleep(30) retry_count += 1 logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) @@ -295,12 +312,11 @@ class TidbService: if not cached_password: logger.warning("No cached password for cluster %s, skipping", item["displayName"]) continue - qdrant_endpoint = TidbService.fetch_qdrant_endpoint( - api_url, public_key, private_key, item["clusterId"] - ) + qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) logger.info( "Batch cluster %s: qdrant_endpoint=%s", - item["clusterId"], qdrant_endpoint, + item["clusterId"], + qdrant_endpoint, ) cluster_info = { "cluster_id": item["clusterId"], diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index e4fca9f931..76802de62e 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -172,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds: # Verify MatchAny structure assert isinstance(field_condition.match, rest.MatchAny) assert field_condition.match.any == ids + + +class TestInitVectorEndpointSelection: + """Test that init_vector selects the correct qdrant endpoint. + + We avoid importing the full module (which triggers Flask app context) + by testing the endpoint selection logic directly on TidbOnQdrantConfig. + """ + + def test_uses_binding_endpoint_when_present(self): + binding_endpoint = "https://qdrant-custom.tidb.com" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-custom.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-custom.tidb.com" + + def test_falls_back_to_global_when_binding_endpoint_is_none(self): + binding_endpoint = None + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-global.tidb.com" + + def test_falls_back_to_empty_when_both_none(self): + binding_endpoint = None + global_url = None + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "" + + def test_binding_endpoint_takes_precedence_over_global(self): + binding_endpoint = "https://qdrant-ap-southeast.tidb.com" + global_url = "https://qdrant-us-east.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-ap-southeast.tidb.com" + + def test_empty_string_binding_endpoint_falls_back_to_global(self): + binding_endpoint = "" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py index e8a013efca..c1ffbacbbc 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -1,15 +1,36 @@ from unittest.mock import MagicMock, patch +import pytest from dify_vdb_tidb_on_qdrant.tidb_service import TidbService +class TestExtractQdrantEndpoint: + """Unit tests for TidbService.extract_qdrant_endpoint.""" + + def test_returns_endpoint_when_host_present(self): + response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}} + result = TidbService.extract_qdrant_endpoint(response) + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + def test_returns_none_when_host_missing(self): + response = {"endpoints": {"public": {}}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_public_missing(self): + response = {"endpoints": {}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_endpoints_missing(self): + assert TidbService.extract_qdrant_endpoint({}) is None + + class TestFetchQdrantEndpoint: """Unit tests for TidbService.fetch_qdrant_endpoint.""" @patch.object(TidbService, "get_tidb_serverless_cluster") def test_returns_endpoint_when_host_present(self, mock_get_cluster): mock_get_cluster.return_value = { - "status": {"connection_strings": {"standard": {"host": "gateway01.us-east-1.tidbcloud.com"}}} + "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}} } result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" @@ -17,65 +38,48 @@ class TestFetchQdrantEndpoint: @patch.object(TidbService, "get_tidb_serverless_cluster") def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster): mock_get_cluster.return_value = None - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None @patch.object(TidbService, "get_tidb_serverless_cluster") def test_returns_none_when_host_missing(self, mock_get_cluster): - mock_get_cluster.return_value = {"status": {"connection_strings": {"standard": {}}}} - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None + mock_get_cluster.return_value = {"endpoints": {"public": {}}} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None @patch.object(TidbService, "get_tidb_serverless_cluster") - def test_returns_none_when_status_missing(self, mock_get_cluster): + def test_returns_none_when_endpoints_missing(self, mock_get_cluster): mock_get_cluster.return_value = {} - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None - - @patch.object(TidbService, "get_tidb_serverless_cluster") - def test_returns_none_when_connection_strings_missing(self, mock_get_cluster): - mock_get_cluster.return_value = {"status": {}} - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None @patch.object(TidbService, "get_tidb_serverless_cluster") def test_returns_none_on_exception(self, mock_get_cluster): mock_get_cluster.side_effect = RuntimeError("network error") - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None - - @patch.object(TidbService, "get_tidb_serverless_cluster") - def test_returns_none_when_standard_key_missing(self, mock_get_cluster): - mock_get_cluster.return_value = {"status": {"connection_strings": {}}} - result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") - assert result is None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None class TestCreateTidbServerlessClusterQdrantEndpoint: """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result.""" - @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") @patch.object(TidbService, "get_tidb_serverless_cluster") @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") - def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep): + def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster): mock_config.TIDB_SPEND_LIMIT = 10 mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) - mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} + mock_get_cluster.return_value = { + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}}, + } result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") assert result is not None assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" - mock_fetch_ep.assert_called_once_with("url", "pub", "priv", "c-1") - @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) @patch.object(TidbService, "get_tidb_serverless_cluster") @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") - def test_result_qdrant_endpoint_none_when_fetch_fails( - self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep - ): + def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster): mock_config.TIDB_SPEND_LIMIT = 10 mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} @@ -115,3 +119,100 @@ class TestBatchCreateTidbServerlessClusterQdrantEndpoint: assert len(result) == 1 assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + )