From 68adcfc474c7be16c150edb6c1abdc6cfea40735 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Tue, 14 Apr 2026 15:58:24 +0800 Subject: [PATCH] add migration --- .../vdb/tidb_on_qdrant/tidb_service.py | 298 ++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py new file mode 100644 index 0000000000..68ac02de25 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -0,0 +1,298 @@ +import logging +import time +import uuid +from collections.abc import Sequence + +import httpx +from httpx import DigestAuth + +from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus + +logger = logging.getLogger(__name__) + +# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn +_tidb_http_client: httpx.Client = get_pooled_http_client( + "tidb:cloud", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + + +class TidbService: + @staticmethod + def fetch_qdrant_endpoint( + api_url: str, public_key: str, private_key: str, cluster_id: str + ) -> str | None: + """Fetch the qdrant endpoint for a cluster by calling the Get Cluster API. + + The Get Cluster response contains ``status.connection_strings.standard.host`` + (e.g. ``gateway01.xx.tidbcloud.com``). We prepend ``qdrant-`` and wrap it + as an ``https://`` URL. + """ + try: + cluster_response = TidbService.get_tidb_serverless_cluster( + api_url, public_key, private_key, cluster_id + ) + if not cluster_response: + return None + # v1beta: status.connection_strings.standard.host + status = cluster_response.get("status") or {} + connection_strings = status.get("connection_strings") or {} + standard = connection_strings.get("standard") or {} + host = standard.get("host") + if host: + return f"https://qdrant-{host}" + except Exception: + logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) + return None + + @staticmethod + def create_tidb_serverless_cluster( + project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ): + """ + Creates a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": dify_config.TIDB_SPEND_LIMIT, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "")[:16] + cluster_data = { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + + response = _tidb_http_client.post( + f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_id = response_data["clusterId"] + retry_count = 0 + max_retries = 30 + while retry_count < max_retries: + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if cluster_response["state"] == "ACTIVE": + user_prefix = cluster_response["userPrefix"] + qdrant_endpoint = TidbService.fetch_qdrant_endpoint( + api_url, public_key, private_key, cluster_id + ) + return { + "cluster_id": cluster_id, + "cluster_name": display_name, + "account": f"{user_prefix}.root", + "password": password, + "qdrant_endpoint": qdrant_endpoint, + } + time.sleep(30) # wait 30 seconds before retrying + retry_count += 1 + else: + response.raise_for_status() + + @staticmethod + def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = _tidb_http_client.delete( + f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def change_tidb_serverless_root_password( + api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str + ): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster for which the password is to be changed (required).+ + :param account: The account for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} + + response = _tidb_http_client.patch( + f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", + json=body, + auth=DigestAuth(public_key, private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def batch_update_tidb_serverless_cluster_status( + tidb_serverless_list: Sequence[TidbAuthBinding], + project_id: str, + api_url: str, + iam_url: str, + public_key: str, + private_key: str, + ): + """ + Update the status of a new TiDB Serverless cluster. + :param tidb_serverless_list: The TiDB serverless list (required). + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + + :return: The response from the API. + """ + tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} + cluster_ids = [item.cluster_id for item in tidb_serverless_list] + params = {"clusterIds": cluster_ids, "view": "BASIC"} + response = _tidb_http_client.get( + f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + for item in response_data["clusters"]: + state = item["state"] + userPrefix = item["userPrefix"] + if state == "ACTIVE" and len(userPrefix) > 0: + cluster_info = tidb_serverless_list_map[item["clusterId"]] + cluster_info.status = TidbAuthBindingStatus.ACTIVE + cluster_info.account = f"{userPrefix}.root" + db.session.add(cluster_info) + db.session.commit() + else: + response.raise_for_status() + + @staticmethod + def batch_create_tidb_serverless_cluster( + batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ) -> list[dict]: + """ + Creates a new TiDB Serverless cluster. + :param batch_size: The batch size (required). + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + clusters = [] + for _ in range(batch_size): + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": dify_config.TIDB_SPEND_LIMIT, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "") + cluster_data = { + "cluster": { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + } + cache_key = f"tidb_serverless_cluster_password:{display_name}" + redis_client.setex(cache_key, 3600, password) + clusters.append(cluster_data) + + request_body = {"requests": clusters} + response = _tidb_http_client.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" + cached_password = redis_client.get(cache_key) + if not cached_password: + continue + cluster_info = { + "cluster_id": item["clusterId"], + "cluster_name": item["displayName"], + "account": "root", + "password": cached_password.decode("utf-8"), + "qdrant_endpoint": TidbService.fetch_qdrant_endpoint( + api_url, public_key, private_key, item["clusterId"] + ), + } + cluster_infos.append(cluster_info) + return cluster_infos + else: + response.raise_for_status() + return []