mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
fix unit test
This commit is contained in:
parent
b9a9c8267e
commit
43e4e161b5
@ -1,10 +1,11 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
from dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector import (
|
from dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector import (
|
||||||
TidbOnQdrantConfig,
|
TidbOnQdrantConfig,
|
||||||
TidbOnQdrantVector,
|
TidbOnQdrantVector,
|
||||||
|
TidbOnQdrantVectorFactory,
|
||||||
)
|
)
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
@ -172,3 +173,62 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
|||||||
# Verify MatchAny structure
|
# Verify MatchAny structure
|
||||||
assert isinstance(field_condition.match, rest.MatchAny)
|
assert isinstance(field_condition.match, rest.MatchAny)
|
||||||
assert field_condition.match.any == ids
|
assert field_condition.match.any == ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitVectorEndpointSelection:
|
||||||
|
"""Test that init_vector selects the correct qdrant endpoint."""
|
||||||
|
|
||||||
|
def _make_dataset(self, tenant_id="t-1", dataset_id="d-1", index_struct_dict=None):
|
||||||
|
ds = MagicMock()
|
||||||
|
ds.tenant_id = tenant_id
|
||||||
|
ds.id = dataset_id
|
||||||
|
ds.index_struct_dict = index_struct_dict
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def _make_binding(self, account="acc", password="pwd", qdrant_endpoint=None, cluster_id="c-1"):
|
||||||
|
b = MagicMock()
|
||||||
|
b.account = account
|
||||||
|
b.password = password
|
||||||
|
b.qdrant_endpoint = qdrant_endpoint
|
||||||
|
b.cluster_id = cluster_id
|
||||||
|
return b
|
||||||
|
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.current_app")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.dify_config")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.db")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient")
|
||||||
|
def test_uses_binding_endpoint_when_present(self, mock_qc, mock_db, mock_config, mock_app):
|
||||||
|
binding = self._make_binding(qdrant_endpoint="https://qdrant-custom.tidb.com")
|
||||||
|
mock_db.session.scalars.return_value.one_or_none.return_value = binding
|
||||||
|
mock_config.TIDB_ON_QDRANT_URL = "https://qdrant-global.tidb.com"
|
||||||
|
mock_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT = 20
|
||||||
|
mock_config.TIDB_ON_QDRANT_GRPC_PORT = 6334
|
||||||
|
mock_config.TIDB_ON_QDRANT_GRPC_ENABLED = False
|
||||||
|
mock_config.QDRANT_REPLICATION_FACTOR = 1
|
||||||
|
mock_app.config = {"root_path": "/app"}
|
||||||
|
|
||||||
|
ds = self._make_dataset(index_struct_dict={"type": "tidb_on_qdrant", "vector_store": {"class_prefix": "col"}})
|
||||||
|
factory = TidbOnQdrantVectorFactory()
|
||||||
|
result = factory.init_vector(ds, [], MagicMock())
|
||||||
|
|
||||||
|
assert result._client_config.endpoint == "https://qdrant-custom.tidb.com"
|
||||||
|
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.current_app")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.dify_config")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.db")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient")
|
||||||
|
def test_falls_back_to_global_when_binding_endpoint_is_none(self, mock_qc, mock_db, mock_config, mock_app):
|
||||||
|
binding = self._make_binding(qdrant_endpoint=None)
|
||||||
|
mock_db.session.scalars.return_value.one_or_none.return_value = binding
|
||||||
|
mock_config.TIDB_ON_QDRANT_URL = "https://qdrant-global.tidb.com"
|
||||||
|
mock_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT = 20
|
||||||
|
mock_config.TIDB_ON_QDRANT_GRPC_PORT = 6334
|
||||||
|
mock_config.TIDB_ON_QDRANT_GRPC_ENABLED = False
|
||||||
|
mock_config.QDRANT_REPLICATION_FACTOR = 1
|
||||||
|
mock_app.config = {"root_path": "/app"}
|
||||||
|
|
||||||
|
ds = self._make_dataset(index_struct_dict={"type": "tidb_on_qdrant", "vector_store": {"class_prefix": "col"}})
|
||||||
|
factory = TidbOnQdrantVectorFactory()
|
||||||
|
result = factory.init_vector(ds, [], MagicMock())
|
||||||
|
|
||||||
|
assert result._client_config.endpoint == "https://qdrant-global.tidb.com"
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
|
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
|
||||||
|
|
||||||
|
|
||||||
@ -118,3 +119,90 @@ class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
|
|||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateTidbServerlessClusterRetry:
|
||||||
|
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
|
||||||
|
|
||||||
|
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||||
|
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
|
||||||
|
mock_config.TIDB_SPEND_LIMIT = 10
|
||||||
|
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||||
|
mock_get_cluster.side_effect = [
|
||||||
|
{"state": "CREATING", "userPrefix": ""},
|
||||||
|
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||||
|
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
|
||||||
|
assert mock_get_cluster.call_count == 2
|
||||||
|
|
||||||
|
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||||
|
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
|
||||||
|
mock_config.TIDB_SPEND_LIMIT = 10
|
||||||
|
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||||
|
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
|
||||||
|
|
||||||
|
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||||
|
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||||
|
def test_raises_on_post_failure(self, mock_config, mock_http):
|
||||||
|
mock_config.TIDB_SPEND_LIMIT = 10
|
||||||
|
mock_response = MagicMock(status_code=400, text="Bad Request")
|
||||||
|
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
|
||||||
|
mock_http.post.return_value = mock_response
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="HTTP 400"):
|
||||||
|
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchCreateEdgeCases:
|
||||||
|
"""Cover logging/edge-case branches in batch_create."""
|
||||||
|
|
||||||
|
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||||
|
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||||
|
mock_config.TIDB_SPEND_LIMIT = 10
|
||||||
|
mock_http.post.return_value = MagicMock(
|
||||||
|
status_code=200,
|
||||||
|
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
|
||||||
|
)
|
||||||
|
mock_redis.setex = MagicMock()
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||||
|
batch_size=1, project_id="proj", api_url="url", iam_url="iam",
|
||||||
|
public_key="pub", private_key="priv", region="us-east-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
|
mock_fetch_ep.assert_not_called()
|
||||||
|
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||||
|
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||||
|
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
|
||||||
|
mock_config.TIDB_SPEND_LIMIT = 10
|
||||||
|
mock_response = MagicMock(status_code=500, text="Server Error")
|
||||||
|
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
|
||||||
|
mock_http.post.return_value = mock_response
|
||||||
|
mock_redis.setex = MagicMock()
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="HTTP 500"):
|
||||||
|
TidbService.batch_create_tidb_serverless_cluster(
|
||||||
|
batch_size=1, project_id="proj", api_url="url", iam_url="iam",
|
||||||
|
public_key="pub", private_key="priv", region="us-east-1",
|
||||||
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user