mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
feat(api): Making WeaviateClient a singleton
Co-authored-by: lijiezhao <lijiezhao@perfect99.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
29b724e23d
commit
3920d67b8e
@ -8,6 +8,7 @@ document embeddings used in retrieval-augmented generation workflows.
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
@ -32,6 +33,9 @@ from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
@ -99,43 +103,52 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
Configures both HTTP and gRPC connections with proper authentication.
|
||||
"""
|
||||
p = urlparse(config.endpoint)
|
||||
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||
http_secure = p.scheme == "https"
|
||||
http_port = p.port or (443 if http_secure else 80)
|
||||
global _weaviate_client
|
||||
if _weaviate_client and _weaviate_client.is_ready():
|
||||
return _weaviate_client
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
with _weaviate_client_lock:
|
||||
if _weaviate_client and _weaviate_client.is_ready():
|
||||
return _weaviate_client
|
||||
|
||||
p = urlparse(config.endpoint)
|
||||
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||
http_secure = p.scheme == "https"
|
||||
http_port = p.port or (443 if http_secure else 80)
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
)
|
||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||
grpc_host = grpc_p.hostname or "localhost"
|
||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||
grpc_secure = grpc_p.scheme == "grpcs"
|
||||
else:
|
||||
# Infer from HTTP endpoint as fallback
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
|
||||
client = weaviate.connect_to_custom(
|
||||
http_host=host,
|
||||
http_port=http_port,
|
||||
http_secure=http_secure,
|
||||
grpc_host=grpc_host,
|
||||
grpc_port=grpc_port,
|
||||
grpc_secure=grpc_secure,
|
||||
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
||||
)
|
||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||
grpc_host = grpc_p.hostname or "localhost"
|
||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||
grpc_secure = grpc_p.scheme == "grpcs"
|
||||
else:
|
||||
# Infer from HTTP endpoint as fallback
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
|
||||
client = weaviate.connect_to_custom(
|
||||
http_host=host,
|
||||
http_port=http_port,
|
||||
http_secure=http_secure,
|
||||
grpc_host=grpc_host,
|
||||
grpc_port=grpc_port,
|
||||
grpc_secure=grpc_secure,
|
||||
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
||||
)
|
||||
if not client.is_ready():
|
||||
raise ConnectionError("Vector database is not ready")
|
||||
|
||||
if not client.is_ready():
|
||||
raise ConnectionError("Vector database is not ready")
|
||||
|
||||
return client
|
||||
_weaviate_client = client
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""Returns the vector database type identifier."""
|
||||
|
||||
@ -0,0 +1,33 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
|
||||
|
||||
def test_init_client_with_valid_config():
|
||||
"""Test successful client initialization with valid configuration."""
|
||||
config = WeaviateConfig(
|
||||
endpoint="http://localhost:8080",
|
||||
api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
|
||||
)
|
||||
|
||||
with patch("weaviate.connect_to_custom") as mock_connect:
|
||||
mock_client = MagicMock()
|
||||
mock_client.is_ready.return_value = True
|
||||
mock_connect.return_value = mock_client
|
||||
|
||||
vector = WeaviateVector(
|
||||
collection_name="test_collection",
|
||||
config=config,
|
||||
attributes=["doc_id"],
|
||||
)
|
||||
|
||||
assert vector._client == mock_client
|
||||
mock_connect.assert_called_once()
|
||||
call_kwargs = mock_connect.call_args[1]
|
||||
assert call_kwargs["http_host"] == "localhost"
|
||||
assert call_kwargs["http_port"] == 8080
|
||||
assert call_kwargs["http_secure"] is False
|
||||
assert call_kwargs["grpc_host"] == "localhost"
|
||||
assert call_kwargs["grpc_port"] == 50051
|
||||
assert call_kwargs["grpc_secure"] is False
|
||||
assert call_kwargs["auth_credentials"] is not None
|
||||
Loading…
Reference in New Issue
Block a user