feat(api): Making WeaviateClient a singleton

Co-authored-by: lijiezhao <lijiezhao@perfect99.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Sage 2026-03-16 13:38:28 +08:00 committed by GitHub
parent 29b724e23d
commit 3920d67b8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 79 additions and 33 deletions

View File

@ -8,6 +8,7 @@ document embeddings used in retrieval-augmented generation workflows.
import datetime
import json
import logging
import threading
import uuid as _uuid
from typing import Any
from urllib.parse import urlparse
@ -32,6 +33,9 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
_weaviate_client: weaviate.WeaviateClient | None = None
_weaviate_client_lock = threading.Lock()
class WeaviateConfig(BaseModel):
"""
@ -99,43 +103,52 @@ class WeaviateVector(BaseVector):
Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
global _weaviate_client
if _weaviate_client and _weaviate_client.is_ready():
return _weaviate_client
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
with _weaviate_client_lock:
if _weaviate_client and _weaviate_client.is_ready():
return _weaviate_client
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client
_weaviate_client = client
return client
def get_type(self) -> str:
"""Returns the vector database type identifier."""

View File

@ -0,0 +1,33 @@
from unittest.mock import MagicMock, patch
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
def test_init_client_with_valid_config():
"""Test successful client initialization with valid configuration."""
config = WeaviateConfig(
endpoint="http://localhost:8080",
api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
)
with patch("weaviate.connect_to_custom") as mock_connect:
mock_client = MagicMock()
mock_client.is_ready.return_value = True
mock_connect.return_value = mock_client
vector = WeaviateVector(
collection_name="test_collection",
config=config,
attributes=["doc_id"],
)
assert vector._client == mock_client
mock_connect.assert_called_once()
call_kwargs = mock_connect.call_args[1]
assert call_kwargs["http_host"] == "localhost"
assert call_kwargs["http_port"] == 8080
assert call_kwargs["http_secure"] is False
assert call_kwargs["grpc_host"] == "localhost"
assert call_kwargs["grpc_port"] == 50051
assert call_kwargs["grpc_secure"] is False
assert call_kwargs["auth_credentials"] is not None